[
  {
    "path": ".gitignore",
    "content": ".DS_Store\n"
  },
  {
    "path": "README.md",
    "content": "# CAE: Context AutoEncoder for Self-Supervised Representation Learning\n\n<p align=\"center\">\n  <img src='furnace/CAE.png'>\n</p>\n\nThis is a PyTorch implementation of [CAE: Context AutoEncoder for Self-Supervised Representation Learning](https://arxiv.org/abs/2202.03026).\n\n## Highlights\n\n- State-of-the-art MIM performance. Results in the paper are successfully reproduced.\n\n## Installation\n\nClone the repo and install the required packages.\n```bash\npip install -r requirements.txt\n\n# install apex\ngit clone https://github.com/NVIDIA/apex\ncd apex\npip install -v --disable-pip-version-check --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./\n```\n\n## Data Preparation\nFirst, download ImageNet-1k from http://image-net.org/.\n\nThe directory structure is the standard layout of torchvision's `datasets.ImageFolder`. The training and validation data are expected to be in the train/ and val/ folders, respectively:\n\n```\n/path/to/imagenet/\n  train/\n    class1/\n      img1.jpeg\n    class2/\n      img2.jpeg\n  val/\n    class1/\n      img3.jpeg\n    class2/\n      img4.jpeg\n```\n\nSecond, download the pretrained DALL-E tokenizer.\n\n```bash\nTOKENIZER_PATH=/path/to/save/dall_e_tokenizer_weight\nmkdir -p $TOKENIZER_PATH\nwget -O $TOKENIZER_PATH/encoder.pkl https://cdn.openai.com/dall-e/encoder.pkl\nwget -O $TOKENIZER_PATH/decoder.pkl https://cdn.openai.com/dall-e/decoder.pkl\n```\n\n\n## Pretraining\n\nHere is an example that pretrains CAE-base on ImageNet-1K with 32 GPUs. Please see [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) for the complete script.\n```bash\nOMP_NUM_THREADS=1 $PYTHON -m torch.distributed.launch \\\n  --nproc_per_node=8 \\\n  tools/run_pretraining.py \\\n  --data_path ${DATA_PATH} \\\n  --output_dir ${OUTPUT_DIR} \\\n  --model cae_base_patch16_224_8k_vocab --discrete_vae_weight_path ${TOKENIZER_PATH} \\\n  --batch_size 64 --lr 1.5e-3 --warmup_epochs 20 --epochs 800 \\\n  --clip_grad 3.0 --layer_scale_init_value 0.1 \\\n  --imagenet_default_mean_and_std \\\n  --color_jitter 0 \\\n  --drop_path 0.1 \\\n  --sincos_pos_emb \\\n  --mask_generator block \\\n  --num_mask_patches 98 \\\n  --decoder_layer_scale_init_value 0.1 \\\n  --no_auto_resume \\\n  --save_ckpt_freq 100 \\\n  --exp_name $my_name \\\n  --regressor_depth 4 \\\n  --decoder_depth 4 \\\n  --align_loss_weight 2\n```\n- `--num_mask_patches`: number of input patches to be masked.\n- `--batch_size`: batch size per GPU.\n- Effective batch size = `number of GPUs` * `--batch_size`. So in the above example, the effective batch size is `64*32 = 2048`.\n- `--lr`: learning rate.\n- `--warmup_epochs`: learning rate warmup epochs. Warm up for 10/20/40 epochs when pretraining for 300/800/1600 epochs, respectively.\n- `--epochs`: total pretraining epochs.\n- `--clip_grad`: clip gradient norm.\n- `--drop_path`: stochastic depth rate.\n- `--imagenet_default_mean_and_std`: enable this for ImageNet-1k pretraining, i.e., `(0.485, 0.456, 0.406)` for mean and `(0.229, 0.224, 0.225)` for std. For other pretraining data, use `(0.5, 0.5, 0.5)` for mean and `(0.5, 0.5, 0.5)` for std by default.\n- `--layer_scale_init_value`: 0.1 for base, 1e-5 for large; set it to 0 to disable LayerScale. We set `--decoder_layer_scale_init_value` to the same value.\n- `--sincos_pos_emb`: use fixed sin-cos positional embeddings during pretraining.\n- `--regressor_depth`: depth (number of layers) of the regressor.\n- `--decoder_depth`: depth (number of layers) of the decoder.\n- `--align_loss_weight`: weight of the alignment loss, 2 by default.\n
\nWarmup epochs for 300/800/1600-epoch pretraining are 10/20/40.\n\nFor CAE-large, please refer to [scripts/cae_large_1600e.sh](scripts/cae_large_1600e.sh).\n\n\n## Results\nHere are the results of CAE-base/CAE-large on the following evaluation tasks:\n- Linear probing\n- Attentive probing\n- Fine-tuning\n- Semantic segmentation\n- Object detection and instance segmentation\n\nPretrained weights and logs are available ([Google Drive](https://drive.google.com/drive/folders/1wwhg7nj2GQuU9uthVuQLkEEXEjx90G7g?usp=sharing), [Baidu Cloud [Code: 4kil]](https://pan.baidu.com/s/15eZGoI72iLupLrOHqmOM9w)). *: results from the CAE paper.\n\n| Model      | Pretraining data | #Epoch | Linear | Attentive | Fine-tuning | ADE Seg | COCO Det | COCO InstSeg |\n| ---------- | ---------------- | ------ | ------ | --------- | ----------- | ------- | -------- | ------------ |\n| MAE-base*  | ImageNet-1K      | 1600   | 67.8   | 74.2      | 83.6        | 48.1    | 48.4     | 42.6         |\n| MAE-large* | ImageNet-1K      | 1600   | 76.0   | 78.8      | 86.0        | 53.6    | 54.0     | 47.1         |\n| CAE-base   | ImageNet-1K      | 300    | 64.5   | 74.0      | 83.6        | 48.1    | 48.3     | 42.7         |\n| CAE-base   | ImageNet-1K      | 800    | 68.9   | 75.9      | 83.8        | 49.7    | 49.9     | 43.9         |\n| CAE-base   | ImageNet-1K      | 1600   | 70.3   | 77.2      | 83.9        | 50.3    | 50.3     | 44.2         |\n| CAE-large  | ImageNet-1K      | 1600   | 77.8   | 81.2      | 86.2        | 54.9    | 54.5     | 47.5         |\n\n\n### Linear Probing\n- Please refer to [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) (32 GPUs).\n- For CAE-large, just replace `--model cae_base_patch16_224` with `--model cae_large_patch16_224`.\n\n### Attentive Probing\n\n- Please refer to [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) (32 GPUs).\n- For CAE-large, just replace `--model cae_base_patch16_224` with `--model cae_large_patch16_224`.\n\n### Fine-tuning\n- Please refer to [scripts/cae_base_finetune.sh](scripts/cae_base_finetune.sh) (32 GPUs).\n- For CAE-large, please refer to [scripts/cae_large_finetune.sh](scripts/cae_large_finetune.sh) (32 GPUs).\n\n### Segmentation & Detection\n- Please refer to the [downstream_tasks](./downstream_tasks) directory to get started.\n\n## Acknowledgement\n\nThis repository is built using the [BEiT](https://github.com/microsoft/unilm/tree/master/beit) and [MMSelfSup](https://github.com/open-mmlab/mmselfsup) codebases; thanks for their open-source code! Thanks also to the CAE authors for their excellent work!\n\n## Citation\n```bibtex\n@article{ContextAutoencoder2022,\n  title={Context Autoencoder for Self-Supervised Representation Learning},\n  author={Chen, Xiaokang and Ding, Mingyu and Wang, Xiaodi and Xin, Ying and Mo, Shentong and Wang, Yunhao and Han, Shumin and Luo, Ping and Zeng, Gang and Wang, Jingdong},\n  journal={arXiv preprint arXiv:2202.03026},\n  year={2022}\n}\n```\n"
  },
  {
    "path": "dall_e/__init__.py",
    "content": "import io, requests\nimport torch\nimport torch.nn as nn\n\nfrom dall_e.encoder import Encoder\nfrom dall_e.decoder import Decoder\nfrom dall_e.utils   import map_pixels, unmap_pixels\n\n# Load a pickled tokenizer module (encoder.pkl / decoder.pkl) from a local path\n# or an http(s) URL.\ndef load_model(path: str, device: torch.device = None) -> nn.Module:\n    if path.startswith('http://') or path.startswith('https://'):\n        resp = requests.get(path)\n        resp.raise_for_status()\n\n        with io.BytesIO(resp.content) as buf:\n            return torch.load(buf, map_location=device)\n    else:\n        with open(path, 'rb') as f:\n            return torch.load(f, map_location=device)\n"
  },
  {
    "path": "dall_e/decoder.py",
    "content": "import attr\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom collections  import OrderedDict\nfrom functools    import partial\nfrom dall_e.utils import Conv2d\n\n@attr.s(eq=False, repr=False)\nclass DecoderBlock(nn.Module):\n\tn_in:     int = attr.ib(validator=lambda i, a, x: x >= 1)\n\tn_out:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0)\n\tn_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)\n\n\tdevice:        torch.device = attr.ib(default=None)\n\trequires_grad: bool         = attr.ib(default=False)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\t\tself.n_hid = self.n_out // 4\n\t\tself.post_gain = 1 / (self.n_layers ** 2)\n\n\t\tmake_conv     = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)\n\t\tself.id_path  = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()\n\t\tself.res_path = nn.Sequential(OrderedDict([\n\t\t\t\t('relu_1', nn.ReLU()),\n\t\t\t\t('conv_1', make_conv(self.n_in,  self.n_hid, 1)),\n\t\t\t\t('relu_2', nn.ReLU()),\n\t\t\t\t('conv_2', make_conv(self.n_hid, self.n_hid, 3)),\n\t\t\t\t('relu_3', nn.ReLU()),\n\t\t\t\t('conv_3', make_conv(self.n_hid, self.n_hid, 3)),\n\t\t\t\t('relu_4', nn.ReLU()),\n\t\t\t\t('conv_4', make_conv(self.n_hid, self.n_out, 3)),]))\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\treturn self.id_path(x) + self.post_gain * self.res_path(x)\n\n@attr.s(eq=False, repr=False)\nclass Decoder(nn.Module):\n\tgroup_count:     int = 4\n\tn_init:          int = attr.ib(default=128,  validator=lambda i, a, x: x >= 8)\n\tn_hid:           int = attr.ib(default=256,  validator=lambda i, a, x: x >= 64)\n\tn_blk_per_group: int = attr.ib(default=2,    validator=lambda i, a, x: x >= 1)\n\toutput_channels: int = attr.ib(default=3,    validator=lambda i, a, x: x >= 1)\n\tvocab_size:      int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)\n\n\tdevice:              torch.device = attr.ib(default=torch.device('cpu'))\n\trequires_grad:       bool         = attr.ib(default=False)\n\tuse_mixed_precision: bool         = attr.ib(default=True)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\n\t\tblk_range  = range(self.n_blk_per_group)\n\t\tn_layers   = self.group_count * self.n_blk_per_group\n\t\tmake_conv  = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)\n\t\tmake_blk   = partial(DecoderBlock, n_layers=n_layers, device=self.device,\n\t\t\t\trequires_grad=self.requires_grad)\n\n\t\tself.blocks = nn.Sequential(OrderedDict([\n\t\t\t('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)),\n\t\t\t('group_1', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],\n\t\t\t\t('upsample', nn.Upsample(scale_factor=2, mode='nearest')),\n\t\t\t]))),\n\t\t\t('group_2', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],\n\t\t\t\t('upsample', nn.Upsample(scale_factor=2, mode='nearest')),\n\t\t\t]))),\n\t\t\t('group_3', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],\n\t\t\t\t('upsample', nn.Upsample(scale_factor=2, mode='nearest')),\n\t\t\t]))),\n\t\t\t('group_4', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(2 * self.n_hid if 
i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],\n\t\t\t]))),\n\t\t\t('output', nn.Sequential(OrderedDict([\n\t\t\t\t('relu', nn.ReLU()),\n\t\t\t\t('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)),\n\t\t\t]))),\n\t\t]))\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\tif len(x.shape) != 4:\n\t\t\traise ValueError(f'input shape {x.shape} is not 4d')\n\t\tif x.shape[1] != self.vocab_size:\n\t\t\traise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')\n\t\tif x.dtype != torch.float32:\n\t\t\traise ValueError('input must have dtype torch.float32')\n\n\t\treturn self.blocks(x)\n"
  },
  {
    "path": "dall_e/encoder.py",
    "content": "import attr\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom collections  import OrderedDict\nfrom functools    import partial\nfrom dall_e.utils import Conv2d\n\n@attr.s(eq=False, repr=False)\nclass EncoderBlock(nn.Module):\n\tn_in:     int = attr.ib(validator=lambda i, a, x: x >= 1)\n\tn_out:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0)\n\tn_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)\n\n\tdevice:        torch.device = attr.ib(default=None)\n\trequires_grad: bool         = attr.ib(default=False)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\t\tself.n_hid = self.n_out // 4\n\t\tself.post_gain = 1 / (self.n_layers ** 2)\n\n\t\tmake_conv     = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)\n\t\tself.id_path  = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()\n\t\tself.res_path = nn.Sequential(OrderedDict([\n\t\t\t\t('relu_1', nn.ReLU()),\n\t\t\t\t('conv_1', make_conv(self.n_in,  self.n_hid, 3)),\n\t\t\t\t('relu_2', nn.ReLU()),\n\t\t\t\t('conv_2', make_conv(self.n_hid, self.n_hid, 3)),\n\t\t\t\t('relu_3', nn.ReLU()),\n\t\t\t\t('conv_3', make_conv(self.n_hid, self.n_hid, 3)),\n\t\t\t\t('relu_4', nn.ReLU()),\n\t\t\t\t('conv_4', make_conv(self.n_hid, self.n_out, 1)),]))\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\treturn self.id_path(x) + self.post_gain * self.res_path(x)\n\n@attr.s(eq=False, repr=False)\nclass Encoder(nn.Module):\n\tgroup_count:     int = 4\n\tn_hid:           int = attr.ib(default=256,  validator=lambda i, a, x: x >= 64)\n\tn_blk_per_group: int = attr.ib(default=2,    validator=lambda i, a, x: x >= 1)\n\tinput_channels:  int = attr.ib(default=3,    validator=lambda i, a, x: x >= 1)\n\tvocab_size:      int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)\n\n\tdevice:              torch.device = attr.ib(default=torch.device('cpu'))\n\trequires_grad:       bool         = attr.ib(default=False)\n\tuse_mixed_precision: bool         = attr.ib(default=True)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\n\t\tblk_range  = range(self.n_blk_per_group)\n\t\tn_layers   = self.group_count * self.n_blk_per_group\n\t\tmake_conv  = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)\n\t\tmake_blk   = partial(EncoderBlock, n_layers=n_layers, device=self.device,\n\t\t\t\trequires_grad=self.requires_grad)\n\n\t\tself.blocks = nn.Sequential(OrderedDict([\n\t\t\t('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),\n\t\t\t('group_1', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],\n\t\t\t\t('pool', nn.MaxPool2d(kernel_size=2)),\n\t\t\t]))),\n\t\t\t('group_2', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],\n\t\t\t\t('pool', nn.MaxPool2d(kernel_size=2)),\n\t\t\t]))),\n\t\t\t('group_3', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],\n\t\t\t\t('pool', nn.MaxPool2d(kernel_size=2)),\n\t\t\t]))),\n\t\t\t('group_4', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],\n\t\t\t]))),\n\t\t\t('output', nn.Sequential(OrderedDict([\n\t\t\t\t('relu', nn.ReLU()),\n\t\t\t\t('conv', 
make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)),\n\t\t\t]))),\n\t\t]))\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\tif len(x.shape) != 4:\n\t\t\traise ValueError(f'input shape {x.shape} is not 4d')\n\t\tif x.shape[1] != self.input_channels:\n\t\t\traise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')\n\t\tif x.dtype != torch.float32:\n\t\t\traise ValueError('input must have dtype torch.float32')\n\n\t\treturn self.blocks(x)\n"
  },
  {
    "path": "dall_e/utils.py",
    "content": "import attr\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nlogit_laplace_eps: float = 0.1\n\n@attr.s(eq=False)\nclass Conv2d(nn.Module):\n\tn_in:  int = attr.ib(validator=lambda i, a, x: x >= 1)\n\tn_out: int = attr.ib(validator=lambda i, a, x: x >= 1)\n\tkw:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)\n\n\tuse_float16:   bool         = attr.ib(default=True)\n\tdevice:        torch.device = attr.ib(default=torch.device('cpu'))\n\trequires_grad: bool         = attr.ib(default=False)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\n\t\tw = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,\n\t\t\tdevice=self.device, requires_grad=self.requires_grad)\n\t\tw.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))\n\n\t\tb = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,\n\t\t\trequires_grad=self.requires_grad)\n\t\tself.w, self.b = nn.Parameter(w), nn.Parameter(b)\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\tif self.use_float16 and 'cuda' in self.w.device.type:\n\t\t\tif x.dtype != torch.float16:\n\t\t\t\tx = x.half()\n\n\t\t\tw, b = self.w.half(), self.b.half()\n\t\telse:\n\t\t\tif x.dtype != torch.float32:\n\t\t\t\tx = x.float()\n\n\t\t\tw, b = self.w, self.b\n\n\t\treturn F.conv2d(x, w, b, padding=(self.kw - 1) // 2)\n\ndef map_pixels(x: torch.Tensor) -> torch.Tensor:\n\tif x.dtype != torch.float:\n\t\traise ValueError('expected input to have type float')\n\n\treturn (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps\n\ndef unmap_pixels(x: torch.Tensor) -> torch.Tensor:\n\tif len(x.shape) != 4:\n\t\traise ValueError('expected input to be 4d')\n\tif x.dtype != torch.float:\n\t\traise ValueError('expected input to have type float')\n\n\treturn torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)\n"
  },
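  {
    "path": "dall_e/usage_example.py",
    "content": "# Illustrative usage sketch for the bundled DALL-E tokenizer (for documentation\n# only; not part of the upstream DALL-E release). The weight path and the\n# 112x112 input size below are placeholders; see the top-level README for the\n# actual download commands and the pretraining script for the real pipeline.\nimport torch\n\nfrom dall_e import load_model, map_pixels\n\n# Load the pretrained discrete VAE encoder downloaded via the README commands.\nenc = load_model('/path/to/save/dall_e_tokenizer_weight/encoder.pkl', torch.device('cpu'))\nenc.eval()\n\n# A dummy RGB batch in [0, 1]; map_pixels shifts it into the value range the\n# tokenizer expects (see dall_e/utils.py).\nx = torch.rand(1, 3, 112, 112)\nwith torch.no_grad():\n    z_logits = enc(map_pixels(x))           # (1, 8192, 14, 14) logits over the visual vocabulary\n    tokens = torch.argmax(z_logits, dim=1)  # (1, 14, 14) discrete token ids (used as MIM targets)\nprint(tokens.shape)\n"
  },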
  {
    "path": "downstream_tasks/detection/README.md",
    "content": "\n# COCO Detection and Instance Segmentation with CAE\n\n## Installation\n\nPlease install [PyTorch](https://pytorch.org/). This codebase has been developed with Python 3.6, PyTorch 1.7.1, CUDA 11.0, and torchvision 0.8.2. To get the full dependencies, please run:\n\n```bash\npip3 install -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.1/index.html mmcv-full==1.3.9\npip3 install pytest-runner scipy tensorboardX faiss-gpu==1.6.1 tqdm lmdb sklearn pyarrow==2.0.0 timm DALL-E munkres six einops\n\n# install apex\npip3 install git+https://github.com/NVIDIA/apex \\\n    --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"\n\n# install mmdetection for object detection & instance segmentation\ngit clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection\ncd Swin-Transformer-Object-Detection\npip3 install -r requirements/build.txt\npip3 install -v -e .\ncd ..\n```\n\n\n## Fine-tuning with Mask R-CNN\n#### We use 16 GPUs (8 per node, `$NNODES=2`) for these experiments.\n\n- To train ViT-B/16 with Mask R-CNN as the task layer, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=$PORT \\\n    evaluation/object_detection/train.py evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py \\\n    --launcher pytorch \\\n    --work-dir $OUTPUT_DIR \\\n    --no-validate \\\n    --deterministic \\\n    --cfg-options model.backbone.use_checkpoint=True \\\n    model.pretrained=$PRETRAINED \\\n    ${@:6}\n```\n\n- To train ViT-L/16 with Mask R-CNN as the task layer, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=$PORT \\\n    evaluation/object_detection/train.py evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py \\\n    --launcher pytorch \\\n    --work-dir $OUTPUT_DIR \\\n    --no-validate \\\n    --deterministic \\\n    --cfg-options model.backbone.use_checkpoint=True \\\n    model.pretrained=$PRETRAINED \\\n    ${@:6}\n```\n\n- To evaluate Mask R-CNN, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    evaluation/object_detection/test.py \\\n    $CONFIG \\\n    $MODEL \\\n    --launcher pytorch \\\n    --eval bbox segm \\\n    --cfg-options model.backbone.use_checkpoint=True \\\n    ${@:6}\n```\n\n## Results (pretrained models are trained on ImageNet-1K without labels)\n| Backbone | #Pretraining Epochs | Object Det | Instance Seg |\n| -------- | ------------------- | ---------- | ------------ |\n| ViT-B    | 300                 | 48.3       | 42.7         |\n| ViT-B    | 800                 | 49.9       | 43.9         |\n| ViT-B    | 1600                | 50.3       | 44.2         |\n| ViT-L    | 1600                | 54.5       | 47.5         |\n\n\n## Acknowledgement\n\nThis repository is built using the [iBOT repository](https://github.com/bytedance/ibot). Thanks for their open-source code!\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/_base_/datasets/coco_instance.py",
    "content": "dataset_type = 'CocoDataset'\ndata_root = '/path/to/coco/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_train2017.json',\n        img_prefix=data_root + 'train2017/',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\nevaluation = dict(metric=['bbox', 'segm'])\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/_base_/default_runtime.py",
    "content": "checkpoint_config = dict(interval=1)\n# yapf:disable\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        # dict(type='TensorboardLoggerHook')\n    ])\n# yapf:enable\ncustom_hooks = [dict(type='NumClassCheckHook')]\n\ndist_params = dict(backend='nccl')\nlog_level = 'INFO'\nload_from = None\nresume_from = None\nworkflow = [('train', 1)]"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='CascadeRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),\n    roi_head=dict(\n        type='CascadeRoIHead',\n        num_stages=3,\n        stage_loss_weights=[1, 0.5, 0.25],\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=[\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.1, 0.1, 0.2, 0.2]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.05, 0.05, 0.1, 0.1]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.033, 0.033, 0.067, 0.067]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))\n        ],\n        mask_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=4,\n            in_channels=256,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=2000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.5,\n                    neg_iou_thr=0.5,\n                    min_pos_iou=0.5,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.6,\n                    neg_iou_thr=0.6,\n                    min_pos_iou=0.6,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.7,\n                    neg_iou_thr=0.7,\n                    min_pos_iou=0.7,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False)\n        ]),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n   
         max_per_img=100,\n            mask_thr_binary=0.5)))\n\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_swin_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='CascadeRCNN',\n    pretrained=None,\n    backbone=dict(\n        type='SwinTransformer',\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=7,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.,\n        attn_drop_rate=0.,\n        drop_path_rate=0.2,\n        ape=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        use_checkpoint=False),\n    neck=dict(\n        type='FPN',\n        in_channels=[96, 192, 384, 768],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),\n    roi_head=dict(\n        type='CascadeRoIHead',\n        num_stages=3,\n        stage_loss_weights=[1, 0.5, 0.25],\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=[\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.1, 0.1, 0.2, 0.2]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.05, 0.05, 0.1, 0.1]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.033, 0.033, 0.067, 0.067]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    
type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))\n        ],\n        mask_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=4,\n            in_channels=256,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg = dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_across_levels=False,\n            nms_pre=2000,\n            nms_post=2000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.5,\n                    neg_iou_thr=0.5,\n                    min_pos_iou=0.5,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.6,\n                    neg_iou_thr=0.6,\n                    min_pos_iou=0.6,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.7,\n                    neg_iou_thr=0.7,\n                    min_pos_iou=0.7,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False)\n        ]),\n    test_cfg = dict(\n        rpn=dict(\n            nms_across_levels=False,\n            
nms_pre=1000,\n            nms_post=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100,\n            mask_thr_binary=0.5)))"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_vit_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='CascadeRCNN',\n    pretrained=None,\n    backbone=dict(\n        type='VisionTransformer',\n        img_size=[672, 1092],\n        patch_size=16,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        drop_path_rate=0.1,\n        out_indices=(3, 5, 7, 11),\n        use_checkpoint=False),\n    neck=dict(\n        type='FPN',\n        in_channels=[384, 384, 384, 384],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),\n    roi_head=dict(\n        type='CascadeRoIHead',\n        num_stages=3,\n        stage_loss_weights=[1, 0.5, 0.25],\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=[\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.1, 0.1, 0.2, 0.2]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.05, 0.05, 0.1, 0.1]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.033, 0.033, 0.067, 0.067]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))\n        ],\n        mask_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=4,\n            in_channels=256,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg = dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_across_levels=False,\n            nms_pre=2000,\n            nms_post=2000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.5,\n                    neg_iou_thr=0.5,\n                    min_pos_iou=0.5,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.6,\n                    neg_iou_thr=0.6,\n                    min_pos_iou=0.6,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.7,\n                    neg_iou_thr=0.7,\n                    min_pos_iou=0.7,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False)\n        ]),\n    test_cfg = dict(\n        rpn=dict(\n            nms_across_levels=False,\n            nms_pre=1000,\n            nms_post=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n 
           min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100,\n            mask_thr_binary=0.5)))"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/mask_rcnn_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='MaskRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    roi_head=dict(\n        type='StandardRoIHead',\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=dict(\n            type='Shared2FCBBoxHead',\n            in_channels=256,\n            fc_out_channels=1024,\n            roi_feat_size=7,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n        mask_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=4,\n            in_channels=256,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=-1,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=2000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n                neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n   
         sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            mask_size=28,\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100,\n            mask_thr_binary=0.5)))\n\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/mask_rcnn_vit_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='MaskRCNN',\n    pretrained=None,\n    backbone=dict(\n        type='VisionTransformer',\n        img_size=[672, 1092],\n        patch_size=16,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        drop_path_rate=0.1,\n        out_indices=(3, 5, 7, 11),\n        use_checkpoint=False),\n    neck=dict(\n        type='FPN',\n        in_channels=[384, 384, 384, 384],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    roi_head=dict(\n        type='StandardRoIHead',\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=dict(\n            type='Shared2FCBBoxHead',\n            in_channels=256,\n            fc_out_channels=1024,\n            roi_feat_size=7,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n        mask_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=4,\n            in_channels=256,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=-1,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_across_levels=False,\n            nms_pre=2000,\n            nms_post=2000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n                neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n          
      match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            mask_size=28,\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms_across_levels=False,\n            nms_pre=1000,\n            nms_post=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100,\n            mask_thr_binary=0.5)))\n\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/_base_/schedules/schedule_1x.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=500,\n    warmup_ratio=0.001,\n    step=[8, 11])\nrunner = dict(type='EpochBasedRunner', max_epochs=12)"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nMostly copy-paste from timm, mmdet, and swin code bases\nhttps://github.com/rwightman/pytorch-image-models/tree/master/timm\nhttps://github.com/open-mmlab/mmdetection\nhttps://github.com/SwinTransformer/Swin-Transformer-Object-Detection\n\"\"\"\n\n_base_ = [\n    '../_base_/models/mask_rcnn_vit_fpn.py',\n    '../_base_/datasets/coco_instance.py',\n    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'\n]\n\nmodel = dict(\n    backbone=dict(\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        init_values=0.1, \n        mlp_ratio=4.,\n        drop_path_rate=0.2, #see if 0.1 larger than vit-small is better\n\t\tuse_abs_pos_emb=False,\n\t\tuse_sincos_pos_emb=True,\n\t\tuse_rel_pos_bias=False,\n    ),\n    neck=dict(in_channels=[768, 768, 768, 768]),\n    roi_head=dict(\n        bbox_head=dict(\n                type='ConvFCBBoxHead',\n                num_shared_convs=4,\n                num_shared_fcs=1,\n                in_channels=256,\n                conv_out_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.1, 0.1, 0.2, 0.2]),\n                reg_class_agnostic=False,\n                reg_decoded_bbox=True,\n                norm_cfg=dict(type='SyncBN', requires_grad=True),\n                loss_cls=dict(\n                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n                loss_bbox=dict(type='GIoULoss', loss_weight=10.0))))\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\n# augmentation strategy originates from DETR / Sparse RCNN\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='AutoAugment',\n         policies=[\n             [\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),\n                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),\n                                 (736, 1333), (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True)\n             ],\n             [\n                 dict(type='Resize',\n                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True),\n                 dict(type='RandomCrop',\n                      crop_type='absolute_range',\n                      crop_size=(384, 600),\n                      allow_negative_crop=True),\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333),\n                                 (576, 1333), (608, 1333), (640, 1333),\n                                 (672, 1333), (704, 1333), (736, 1333),\n                                 (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      override=True,\n                      keep_ratio=True)\n             ]\n         
]),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\ndata = dict(\n    \t\tsamples_per_gpu=2,\n   \t\t\tworkers_per_gpu=2,\n\t\t\ttrain=dict(pipeline=train_pipeline))\n\noptimizer = dict(_delete_=True, type='AdamW', lr=0.0003, betas=(0.9, 0.999), weight_decay=0.05,\n                 constructor='LayerDecayOptimizerConstructor', \n                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.75))\n\nlr_config = dict(step=[9, 11])\nrunner = dict(type='EpochBasedRunnerAmp', max_epochs=12)\n\n# do not use mmdet version fp16\nfp16 = None\noptimizer_config = dict(\n    type=\"DistOptimizerHook\",\n    update_interval=1,\n    grad_clip=None,\n    coalesce=True,\n    bucket_size_mb=-1,\n    use_fp16=True,\n)\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py",
    "content": "#Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nMostly copy-paste from timm, mmdet, and swin code bases\nhttps://github.com/rwightman/pytorch-image-models/tree/master/timm\nhttps://github.com/open-mmlab/mmdetection\nhttps://github.com/SwinTransformer/Swin-Transformer-Object-Detection\n\"\"\"\n\n_base_ = [\n    '../_base_/models/mask_rcnn_vit_fpn.py',\n    '../_base_/datasets/coco_instance.py',\n    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'\n]\n\nfind_unused_parameters = False\nmodel = dict(\n    backbone=dict(\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        init_values=0.00001, \n        mlp_ratio=4.,\n        drop_path_rate=0.2, #see if 0.1 larger than vit-small is better\n\t\tuse_abs_pos_emb=False,\n\t\tuse_sincos_pos_emb=True,\n\t\tuse_rel_pos_bias=False,\n\t\tout_indices=[7, 11, 15, 23],\n    ),\n    neck=dict(in_channels=[1024, 1024, 1024, 1024]),\n    roi_head=dict(\n        bbox_head=dict(\n                type='ConvFCBBoxHead',\n                num_shared_convs=4,\n                num_shared_fcs=1,\n                in_channels=256,\n                conv_out_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.1, 0.1, 0.2, 0.2]),\n                reg_class_agnostic=False,\n                reg_decoded_bbox=True,\n                norm_cfg=dict(type='SyncBN', requires_grad=True),\n                loss_cls=dict(\n                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n                loss_bbox=dict(type='GIoULoss', loss_weight=10.0))))\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\n# augmentation strategy originates from DETR / Sparse RCNN\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='AutoAugment',\n         policies=[\n             [\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),\n                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),\n                                 (736, 1333), (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True)\n             ],\n             [\n                 dict(type='Resize',\n                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True),\n                 dict(type='RandomCrop',\n                      crop_type='absolute_range',\n                      crop_size=(384, 600),\n                      allow_negative_crop=True),\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333),\n                                 (576, 1333), (608, 1333), (640, 1333),\n                                 (672, 1333), (704, 1333), (736, 1333),\n                                 (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      
override=True,\n                      keep_ratio=True)\n             ]\n         ]),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\ndata = dict(\n    \t\tsamples_per_gpu=2,\n   \t\t\tworkers_per_gpu=2,\n\t\t\ttrain=dict(pipeline=train_pipeline))\n\noptimizer = dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,\n                 constructor='LayerDecayOptimizerConstructor', \n                 paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.85))\n\nlr_config = dict(step=[9, 11])\nrunner = dict(type='EpochBasedRunnerAmp', max_epochs=12)\n\n# do not use mmdet version fp16\nfp16 = None\noptimizer_config = dict(\n    type=\"DistOptimizerHook\",\n    update_interval=1,\n    grad_clip=None,\n    coalesce=True,\n    bucket_size_mb=-1,\n    use_fp16=True,\n)\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/mmcv_custom/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom .checkpoint import load_checkpoint\nfrom .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor\nfrom .register_backbone import VisionTransformer\n\n__all__ = ['load_checkpoint', 'LayerDecayOptimizerConstructor']\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/mmcv_custom/checkpoint.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nCopy-paste from mmcv library:\nhttps://github.com/open-mmlab/mmcv/\n\"\"\"\n\nimport io\nimport os\nimport os.path as osp\nimport pkgutil\nimport time\nimport warnings\nfrom collections import OrderedDict\nfrom importlib import import_module\nfrom tempfile import TemporaryDirectory\n\nimport torch\nimport torchvision\nfrom torch.optim import Optimizer\nfrom torch.nn import functional as F\n\nimport mmcv\nfrom mmcv.fileio import FileClient\nfrom mmcv.fileio import load as load_file\nfrom mmcv.parallel import is_module_wrapper\nfrom mmcv.utils import mkdir_or_exist\nfrom mmcv.runner import get_dist_info\n\nfrom scipy import interpolate\nimport numpy as np\nimport math\n\nENV_MMCV_HOME = 'MMCV_HOME'\nENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'\nDEFAULT_CACHE_DIR = '~/.cache'\n\n\ndef _get_mmcv_home():\n    mmcv_home = os.path.expanduser(\n        os.getenv(\n            ENV_MMCV_HOME,\n            os.path.join(\n                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))\n\n    mkdir_or_exist(mmcv_home)\n    return mmcv_home\n\n\ndef load_state_dict(module, state_dict, strict=False, logger=None):\n    \"\"\"Load state_dict to a module.\n\n    This method is modified from :meth:`torch.nn.Module.load_state_dict`.\n    Default value for ``strict`` is set to ``False`` and the message for\n    param mismatch will be shown even if strict is False.\n\n    Args:\n        module (Module): Module that receives the state_dict.\n        state_dict (OrderedDict): Weights.\n        strict (bool): whether to strictly enforce that the keys\n            in :attr:`state_dict` match the keys returned by this module's\n            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.\n        logger (:obj:`logging.Logger`, optional): Logger to log the error\n            message. 
If not specified, print function will be used.\n    \"\"\"\n    unexpected_keys = []\n    all_missing_keys = []\n    err_msg = []\n\n    metadata = getattr(state_dict, '_metadata', None)\n    state_dict = state_dict.copy()\n    if metadata is not None:\n        state_dict._metadata = metadata\n\n    # use _load_from_state_dict to enable checkpoint version control\n    def load(module, prefix=''):\n        # recursively check parallel module in case that the model has a\n        # complicated structure, e.g., nn.Module(nn.Module(DDP))\n        if is_module_wrapper(module):\n            module = module.module\n        local_metadata = {} if metadata is None else metadata.get(\n            prefix[:-1], {})\n        module._load_from_state_dict(state_dict, prefix, local_metadata, True,\n                                     all_missing_keys, unexpected_keys,\n                                     err_msg)\n        for name, child in module._modules.items():\n            if child is not None:\n                load(child, prefix + name + '.')\n\n    load(module)\n    load = None  # break load->load reference cycle\n\n    # ignore \"num_batches_tracked\" of BN layers\n    missing_keys = [\n        key for key in all_missing_keys if 'num_batches_tracked' not in key\n    ]\n\n    if unexpected_keys:\n        err_msg.append('unexpected key in source '\n                       f'state_dict: {\", \".join(unexpected_keys)}\\n')\n    if missing_keys:\n        err_msg.append(\n            f'missing keys in source state_dict: {\", \".join(missing_keys)}\\n')\n\n    rank, _ = get_dist_info()\n    if len(err_msg) > 0 and rank == 0:\n        err_msg.insert(\n            0, 'The model and loaded state dict do not match exactly\\n')\n        err_msg = '\\n'.join(err_msg)\n        if strict:\n            raise RuntimeError(err_msg)\n        elif logger is not None:\n            logger.warning(err_msg)\n        else:\n            print(err_msg)\n\n\ndef load_url_dist(url, model_dir=None, map_location=\"cpu\"):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    if rank == 0:\n        checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)\n    return checkpoint\n\n\ndef load_pavimodel_dist(model_path, map_location=None):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    try:\n        from pavi import modelscloud\n    except ImportError:\n        raise ImportError(\n            'Please install pavi to load checkpoint from modelcloud.')\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    if rank == 0:\n        model = modelcloud.get(model_path)\n        with TemporaryDirectory() as tmp_dir:\n            downloaded_file = osp.join(tmp_dir, model.name)\n            model.download(downloaded_file)\n            checkpoint = torch.load(downloaded_file, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            model = modelcloud.get(model_path)\n            with TemporaryDirectory() as tmp_dir:\n                downloaded_file = osp.join(tmp_dir, model.name)\n                model.download(downloaded_file)\n  
              checkpoint = torch.load(\n                    downloaded_file, map_location=map_location)\n    return checkpoint\n\n\ndef load_fileclient_dist(filename, backend, map_location):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    allowed_backends = ['ceph']\n    if backend not in allowed_backends:\n        raise ValueError(f'Load from Backend {backend} is not supported.')\n    if rank == 0:\n        fileclient = FileClient(backend=backend)\n        buffer = io.BytesIO(fileclient.get(filename))\n        checkpoint = torch.load(buffer, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            fileclient = FileClient(backend=backend)\n            buffer = io.BytesIO(fileclient.get(filename))\n            checkpoint = torch.load(buffer, map_location=map_location)\n    return checkpoint\n\n\ndef get_torchvision_models():\n    model_urls = dict()\n    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):\n        if ispkg:\n            continue\n        _zoo = import_module(f'torchvision.models.{name}')\n        if hasattr(_zoo, 'model_urls'):\n            _urls = getattr(_zoo, 'model_urls')\n            model_urls.update(_urls)\n    return model_urls\n\n\ndef get_external_models():\n    mmcv_home = _get_mmcv_home()\n    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')\n    default_urls = load_file(default_json_path)\n    assert isinstance(default_urls, dict)\n    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')\n    if osp.exists(external_json_path):\n        external_urls = load_file(external_json_path)\n        assert isinstance(external_urls, dict)\n        default_urls.update(external_urls)\n\n    return default_urls\n\n\ndef get_mmcls_models():\n    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')\n    mmcls_urls = load_file(mmcls_json_path)\n\n    return mmcls_urls\n\n\ndef get_deprecated_model_names():\n    deprecate_json_path = osp.join(mmcv.__path__[0],\n                                   'model_zoo/deprecated.json')\n    deprecate_urls = load_file(deprecate_json_path)\n    assert isinstance(deprecate_urls, dict)\n\n    return deprecate_urls\n\n\ndef _process_mmcls_checkpoint(checkpoint):\n    state_dict = checkpoint['state_dict']\n    new_state_dict = OrderedDict()\n    for k, v in state_dict.items():\n        if k.startswith('backbone.'):\n            new_state_dict[k[9:]] = v\n    new_checkpoint = dict(state_dict=new_state_dict)\n\n    return new_checkpoint\n\n\ndef _load_checkpoint(filename, map_location=None):\n    \"\"\"Load checkpoint from somewhere (modelzoo, file, url).\n\n    Args:\n        filename (str): Accept local filepath, URL, ``torchvision://xxx``,\n            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for\n            details.\n        map_location (str | None): Same as :func:`torch.load`. Default: None.\n\n    Returns:\n        dict | OrderedDict: The loaded checkpoint. 
It can be either an\n            OrderedDict storing model weights or a dict containing other\n            information, which depends on the checkpoint.\n    \"\"\"\n    if filename.startswith('modelzoo://'):\n        warnings.warn('The URL scheme of \"modelzoo://\" is deprecated, please '\n                      'use \"torchvision://\" instead')\n        model_urls = get_torchvision_models()\n        model_name = filename[11:]\n        checkpoint = load_url_dist(model_urls[model_name])\n    elif filename.startswith('torchvision://'):\n        model_urls = get_torchvision_models()\n        model_name = filename[14:]\n        checkpoint = load_url_dist(model_urls[model_name])\n    elif filename.startswith('open-mmlab://'):\n        model_urls = get_external_models()\n        model_name = filename[13:]\n        deprecated_urls = get_deprecated_model_names()\n        if model_name in deprecated_urls:\n            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '\n                          f'of open-mmlab://{deprecated_urls[model_name]}')\n            model_name = deprecated_urls[model_name]\n        model_url = model_urls[model_name]\n        # check if is url\n        if model_url.startswith(('http://', 'https://')):\n            checkpoint = load_url_dist(model_url)\n        else:\n            filename = osp.join(_get_mmcv_home(), model_url)\n            if not osp.isfile(filename):\n                raise IOError(f'{filename} is not a checkpoint file')\n            checkpoint = torch.load(filename, map_location=map_location)\n    elif filename.startswith('mmcls://'):\n        model_urls = get_mmcls_models()\n        model_name = filename[8:]\n        checkpoint = load_url_dist(model_urls[model_name])\n        checkpoint = _process_mmcls_checkpoint(checkpoint)\n    elif filename.startswith(('http://', 'https://')):\n        checkpoint = load_url_dist(filename)\n    elif filename.startswith('pavi://'):\n        model_path = filename[7:]\n        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)\n    elif filename.startswith('s3://'):\n        checkpoint = load_fileclient_dist(\n            filename, backend='ceph', map_location=map_location)\n    else:\n        if not osp.isfile(filename):\n            raise IOError(f'{filename} is not a checkpoint file')\n        checkpoint = torch.load(filename, map_location=map_location)\n    return checkpoint\n\n\ndef cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,\n                     start_warmup_value=0, warmup_steps=-1):\n    warmup_schedule = np.array([])\n    warmup_iters = warmup_epochs * niter_per_ep\n    if warmup_steps > 0:\n        warmup_iters = warmup_steps\n    print(\"Set warmup steps = %d\" % warmup_iters)\n    if warmup_epochs > 0:\n        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)\n\n    iters = np.arange(epochs * niter_per_ep - warmup_iters)\n    schedule = np.array(\n        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])\n\n    schedule = np.concatenate((warmup_schedule, schedule))\n\n    assert len(schedule) == epochs * niter_per_ep\n    return schedule\n\n\ndef load_checkpoint(model,\n                    filename,\n                    map_location='cpu',\n                    strict=False,\n                    logger=None):\n    \"\"\"Load checkpoint from a file or URI.\n\n    Args:\n        model (Module): Module to load checkpoint.\n        filename (str): Accept 
local filepath, URL, ``torchvision://xxx``,\n            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for\n            details.\n        map_location (str): Same as :func:`torch.load`.\n        strict (bool): Whether to allow different params for the model and\n            checkpoint.\n        logger (:mod:`logging.Logger` or None): The logger for error message.\n\n    Returns:\n        dict or OrderedDict: The loaded checkpoint.\n    \"\"\"\n    checkpoint = _load_checkpoint(filename, map_location)\n    # OrderedDict is a subclass of dict\n    if not isinstance(checkpoint, dict):\n        raise RuntimeError(\n            f'No state_dict found in checkpoint file {filename}')\n    # get state_dict from checkpoint\n    if 'state_dict' in checkpoint:\n        state_dict = checkpoint['state_dict']\n    elif 'model' in checkpoint:\n        state_dict = checkpoint['model']\n    elif 'module' in checkpoint:\n        state_dict = checkpoint['module']\n    else:\n        state_dict = checkpoint\n    # strip prefix of state_dict\n    if list(state_dict.keys())[0].startswith('module.'):\n        state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n    # for MoBY, load model of online branch\n    if sorted(list(state_dict.keys()))[0].startswith('encoder'):\n        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}\n\n\n    all_keys = list(state_dict.keys())\n    if all_keys[-1].startswith('encoder_to_decoder') or all_keys[-1].startswith('decoder'):\n        # NOTE: remove all decoder keys\n        all_keys = [key for key in all_keys if key.startswith('encoder.')]\n        for key in all_keys:\n            new_key = key.replace('encoder.','')\n            state_dict[new_key] = state_dict[key]\n            state_dict.pop(key)\n            \n        for key in list(state_dict.keys()):\n            if key.startswith('decoder.'):\n                state_dict.pop(key)\n\n        # NOTE: replace norm with fc_norm\n        for key in list(state_dict.keys()):\n            if key.startswith('norm.'):\n                new_key = key.replace('norm.','fc_norm.')\n                state_dict[new_key] = state_dict[key]\n                state_dict.pop(key)\n\n    # reshape absolute position embedding for Swin\n    if state_dict.get('absolute_pos_embed') is not None:\n        absolute_pos_embed = state_dict['absolute_pos_embed']\n        N1, L, C1 = absolute_pos_embed.size()\n        N2, C2, H, W = model.absolute_pos_embed.size()\n        if N1 != N2 or C1 != C2 or L != H*W:\n            logger.warning(\"Error in loading absolute_pos_embed, pass\")\n        else:\n            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)\n    \n    rank, _ = get_dist_info()\n    if \"rel_pos_bias.relative_position_bias_table\" in state_dict:\n        if rank == 0:\n            rel_pos_bias = state_dict[\"rel_pos_bias.relative_position_bias_table\"]\n            state_dict[\"relative_position_bias_table\"] = rel_pos_bias\n            state_dict.pop(\"rel_pos_bias.relative_position_bias_table\")\n    all_keys = list(state_dict.keys())\n    for key in all_keys:\n        if \"relative_position_index\" in key:\n            state_dict.pop(key)\n\n        if \"relative_position_bias_table\" in key and key in model.state_dict():\n            rel_pos_bias = state_dict[key]\n            src_num_pos, num_attn_heads = rel_pos_bias.size()\n            dst_num_pos, _ = model.state_dict()[key].size()\n            dst_patch_shape = 
model.patch_embed.patch_shape\n            if dst_patch_shape[0] != dst_patch_shape[1]:\n                num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)\n\n                src_size = int((src_num_pos - num_extra_tokens) ** 0.5)   # 27 \n                \n                dst_size_0 = dst_patch_shape[0] * 2 - 1   # 42\n                dst_size_1 = dst_patch_shape[1] * 2 - 1   # 68\n\n                if src_size != dst_size_0:\n                    if rank == 0:\n                        print(\"Position interpolate for %s from %dx%d to %dx%d\" % (\n                            key, src_size, src_size, dst_size_0, dst_size_1))\n\n                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\n                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\n\n                    def geometric_progression(a, r, n):\n                        return a * (1.0 - r ** n) / (1.0 - r)\n\n                    left, right = 1.01, 1.5\n                    while right - left > 1e-6:\n                        q = (left + right) / 2.0\n                        gp = geometric_progression(1, q, src_size // 2)\n                        if gp > dst_size_0 // 2:\n                            right = q\n                        else:\n                            left = q\n\n                    dis_0 = []\n                    cur = 1\n                    for i in range(src_size // 2):\n                        dis_0.append(cur)\n                        cur += q ** (i + 1)\n                    \n                    r_ids_0 = [-_ for _ in reversed(dis_0)]\n\n                    top, bottom = 1.01, 1.5\n                    while bottom - top > 1e-6:\n                        q = (top + bottom) / 2.0\n                        gp = geometric_progression(1, q, src_size // 2)\n                        if gp > dst_size_1 // 2:\n                            bottom = q\n                        else:\n                            top = q\n\n                    dis_1 = []\n                    cur = 1\n                    for i in range(src_size // 2):\n                        dis_1.append(cur)\n                        cur += q ** (i + 1)\n                    \n                    r_ids_1 = [-_ for _ in reversed(dis_1)]\n\n                    # if q > 1.13492:\n                    #     q = 1.13492\n\n                    x = r_ids_0 + [0] + dis_0\n                    y = r_ids_1 + [0] + dis_1\n\n                    t_0 = dst_size_0 // 2.0\n                    t_1 = dst_size_1 // 2.0\n\n                    dx = np.arange(-t_0, t_0 + 0.1, 1.0)\n                    dy = np.arange(-t_1, t_1 + 0.1, 1.0)\n\n                    if rank == 0:\n                        print(\"x = {}\".format(x))\n                        print(\"dx = {}\".format(dx))\n\n                    all_rel_pos_bias = []\n\n                    for i in range(num_attn_heads):\n                        z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()\n                        f = interpolate.interp2d(x, y, z, kind='cubic')\n                        all_rel_pos_bias.append(\n                            torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))\n\n                    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n                    new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)\n                    state_dict[key] = new_rel_pos_bias\n\n            else:\n                num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 
- 1)\n\n                src_size = int((src_num_pos - num_extra_tokens) ** 0.5)   # 27 \n                dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)   # \n\n                if src_size != dst_size:\n                    if rank == 0:\n                        print(\"Position interpolate for %s from %dx%d to %dx%d\" % (\n                            key, src_size, src_size, dst_size, dst_size))\n\n                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\n                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\n\n                    def geometric_progression(a, r, n):\n                        return a * (1.0 - r ** n) / (1.0 - r)\n\n                    left, right = 1.01, 1.5\n                    while right - left > 1e-6:\n                        q = (left + right) / 2.0\n                        gp = geometric_progression(1, q, src_size // 2)\n                        if gp > dst_size // 2:\n                            right = q\n                        else:\n                            left = q\n\n                    # if q > 1.13492:\n                    #     q = 1.13492\n\n                    dis = []\n                    cur = 1\n                    for i in range(src_size // 2):\n                        dis.append(cur)\n                        cur += q ** (i + 1)\n\n                    r_ids = [-_ for _ in reversed(dis)]\n\n                    x = r_ids + [0] + dis\n                    y = r_ids + [0] + dis\n\n                    t = dst_size // 2.0\n                    dx = np.arange(-t, t + 0.1, 1.0)\n                    dy = np.arange(-t, t + 0.1, 1.0)\n                    if rank == 0:\n                        print(\"x = {}\".format(x))\n                        print(\"dx = {}\".format(dx))\n\n                    all_rel_pos_bias = []\n\n                    for i in range(num_attn_heads):\n                        z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()\n                        f = interpolate.interp2d(x, y, z, kind='cubic')\n                        all_rel_pos_bias.append(\n                            torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))\n\n                    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n                    new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)\n                    state_dict[key] = new_rel_pos_bias\n\n\n    if 'pos_embed' in state_dict:\n        pos_embed_checkpoint = state_dict['pos_embed']\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        #new_size = int(num_patches ** 0.5)\n        new_size_w = model.patch_embed.num_patches_w\n        new_size_h = model.patch_embed.num_patches_h\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size_h or orig_size != new_size_w:\n            if rank == 0:\n                print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size_w, new_size_h))\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(-1, 
orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens, size=(new_size_w, new_size_h), mode='bicubic', align_corners=False)\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            state_dict['pos_embed'] = new_pos_embed\n\n    # interpolate position bias table if needed\n    relative_position_bias_table_keys = [k for k in state_dict.keys() if \"relative_position_bias_table\" in k and  k in model.state_dict()]\n    for table_key in relative_position_bias_table_keys:\n        table_pretrained = state_dict[table_key]\n        table_current = model.state_dict()[table_key]\n        L1, nH1 = table_pretrained.size()\n        L2, nH2 = table_current.size()\n        if nH1 != nH2:\n            logger.warning(f\"Error in loading {table_key}, pass\")\n        else:\n            if L1 != L2:\n                S1 = int(L1 ** 0.5)\n                S2 = int(L2 ** 0.5)\n                table_pretrained_resized = F.interpolate(\n                     table_pretrained.permute(1, 0).view(1, nH1, S1, S1),\n                     size=(S2, S2), mode='bicubic')\n                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)\n\n    # load state_dict\n    load_state_dict(model, state_dict, strict, logger)\n    return checkpoint\n\n\ndef weights_to_cpu(state_dict):\n    \"\"\"Copy a model state_dict to cpu.\n\n    Args:\n        state_dict (OrderedDict): Model weights on GPU.\n\n    Returns:\n        OrderedDict: Model weights on GPU.\n    \"\"\"\n    state_dict_cpu = OrderedDict()\n    for key, val in state_dict.items():\n        state_dict_cpu[key] = val.cpu()\n    return state_dict_cpu\n\n\ndef _save_to_state_dict(module, destination, prefix, keep_vars):\n    \"\"\"Saves module state to `destination` dictionary.\n\n    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.\n\n    Args:\n        module (nn.Module): The module to generate state_dict.\n        destination (dict): A dict where state will be stored.\n        prefix (str): The prefix for parameters and buffers used in this\n            module.\n    \"\"\"\n    for name, param in module._parameters.items():\n        if param is not None:\n            destination[prefix + name] = param if keep_vars else param.detach()\n    for name, buf in module._buffers.items():\n        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d\n        if buf is not None:\n            destination[prefix + name] = buf if keep_vars else buf.detach()\n\n\ndef get_state_dict(module, destination=None, prefix='', keep_vars=False):\n    \"\"\"Returns a dictionary containing a whole state of the module.\n\n    Both parameters and persistent buffers (e.g. running averages) are\n    included. Keys are corresponding parameter and buffer names.\n\n    This method is modified from :meth:`torch.nn.Module.state_dict` to\n    recursively check parallel module in case that the model has a complicated\n    structure, e.g., nn.Module(nn.Module(DDP)).\n\n    Args:\n        module (nn.Module): The module to generate state_dict.\n        destination (OrderedDict): Returned dict for the state of the\n            module.\n        prefix (str): Prefix of the key.\n        keep_vars (bool): Whether to keep the variable property of the\n            parameters. 
Default: False.\n\n    Returns:\n        dict: A dictionary containing a whole state of the module.\n    \"\"\"\n    # recursively check parallel module in case that the model has a\n    # complicated structure, e.g., nn.Module(nn.Module(DDP))\n    if is_module_wrapper(module):\n        module = module.module\n\n    # below is the same as torch.nn.Module.state_dict()\n    if destination is None:\n        destination = OrderedDict()\n        destination._metadata = OrderedDict()\n    destination._metadata[prefix[:-1]] = local_metadata = dict(\n        version=module._version)\n    _save_to_state_dict(module, destination, prefix, keep_vars)\n    for name, child in module._modules.items():\n        if child is not None:\n            get_state_dict(\n                child, destination, prefix + name + '.', keep_vars=keep_vars)\n    for hook in module._state_dict_hooks.values():\n        hook_result = hook(module, destination, prefix, local_metadata)\n        if hook_result is not None:\n            destination = hook_result\n    return destination\n\n\ndef save_checkpoint(model, filename, optimizer=None, meta=None):\n    \"\"\"Save checkpoint to file.\n\n    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and\n    ``optimizer``. By default ``meta`` will contain version and time info.\n\n    Args:\n        model (Module): Module whose params are to be saved.\n        filename (str): Checkpoint filename.\n        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.\n        meta (dict, optional): Metadata to be saved in checkpoint.\n    \"\"\"\n    if meta is None:\n        meta = {}\n    elif not isinstance(meta, dict):\n        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')\n    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())\n\n    if is_module_wrapper(model):\n        model = model.module\n\n    if hasattr(model, 'CLASSES') and model.CLASSES is not None:\n        # save class name to the meta\n        meta.update(CLASSES=model.CLASSES)\n\n    checkpoint = {\n        'meta': meta,\n        'state_dict': weights_to_cpu(get_state_dict(model))\n    }\n    # save optimizer state dict in the checkpoint\n    if isinstance(optimizer, Optimizer):\n        checkpoint['optimizer'] = optimizer.state_dict()\n    elif isinstance(optimizer, dict):\n        checkpoint['optimizer'] = {}\n        for name, optim in optimizer.items():\n            checkpoint['optimizer'][name] = optim.state_dict()\n\n    if filename.startswith('pavi://'):\n        try:\n            from pavi import modelscloud\n            from pavi.exception import NodeNotFoundError\n        except ImportError:\n            raise ImportError(\n                'Please install pavi to load checkpoint from modelcloud.')\n        model_path = filename[7:]\n        root = modelcloud.Folder()\n        model_dir, model_name = osp.split(model_path)\n        try:\n            model = modelcloud.get(model_dir)\n        except NodeNotFoundError:\n            model = root.create_training_model(model_dir)\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint_file = osp.join(tmp_dir, model_name)\n            with open(checkpoint_file, 'wb') as f:\n                torch.save(checkpoint, f)\n                f.flush()\n            model.create_file(checkpoint_file, name=model_name)\n    else:\n        mmcv.mkdir_or_exist(osp.dirname(filename))\n        # immediately flush buffer\n        with open(filename, 'wb') as f:\n            torch.save(checkpoint, f)\n            f.flush()\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/mmcv_custom/layer_decay_optimizer_constructor.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nMostly copy-paste from BEiT library:\nhttps://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py\n\"\"\"\n\nimport json\n\nfrom mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor\nfrom mmcv.runner import get_dist_info\n\ndef get_num_layer_for_vit(var_name, num_max_layer):\n    if var_name in (\"backbone.cls_token\", \"backbone.mask_token\", \"backbone.pos_embed\"):\n        return 0\n    elif var_name.startswith(\"backbone.patch_embed\"):\n        return 0\n    elif var_name.startswith(\"backbone.blocks\"):\n        layer_id = int(var_name.split('.')[2])\n        return layer_id + 1\n    else:\n        return num_max_layer - 1\n\n\n@OPTIMIZER_BUILDERS.register_module()\nclass LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):\n    def add_params(self, params, module, prefix='', is_dcn_module=None):\n        \"\"\"Add all parameters of module to the params list.\n        The parameters of the given module will be added to the list of param\n        groups, with specific rules defined by paramwise_cfg.\n        Args:\n            params (list[dict]): A list of param groups, it will be modified\n                in place.\n            module (nn.Module): The module to be added.\n            prefix (str): The prefix of the module\n            is_dcn_module (int|float|None): If the current module is a\n                submodule of DCN, `is_dcn_module` will be passed to\n                control conv_offset layer's learning rate. Defaults to None.\n        \"\"\"\n        parameter_groups = {}\n        print(self.paramwise_cfg)\n        num_layers = self.paramwise_cfg.get('num_layers') + 2\n        layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')\n        print(\"Build LayerDecayOptimizerConstructor %f - %d\" % (layer_decay_rate, num_layers))\n        weight_decay = self.base_wd\n\n        for name, param in module.named_parameters():\n            if not param.requires_grad:\n                continue  # frozen weights\n            if len(param.shape) == 1 or name.endswith(\".bias\") or name in ('pos_embed', 'cls_token'):\n                group_name = \"no_decay\"\n                this_weight_decay = 0.\n            else:\n                group_name = \"decay\"\n                this_weight_decay = weight_decay\n\n            layer_id = get_num_layer_for_vit(name, num_layers)\n            group_name = \"layer_%d_%s\" % (layer_id, group_name)\n\n            if group_name not in parameter_groups:\n                scale = layer_decay_rate ** (num_layers - layer_id - 1)\n\n                parameter_groups[group_name] = {\n                    \"weight_decay\": this_weight_decay,\n                    \"params\": [],\n                    \"param_names\": [], \n                    \"lr_scale\": scale, \n                    \"group_name\": group_name, \n                    \"lr\": scale * self.base_lr, \n                }\n\n            parameter_groups[group_name][\"params\"].append(param)\n            parameter_groups[group_name][\"param_names\"].append(name)\n        rank, _ = get_dist_info()\n        if rank == 0:\n            to_display = {}\n            for key in parameter_groups:\n                to_display[key] = {\n                    \"param_names\": 
parameter_groups[key][\"param_names\"], \n                    \"lr_scale\": parameter_groups[key][\"lr_scale\"], \n                    \"lr\": parameter_groups[key][\"lr\"], \n                    \"weight_decay\": parameter_groups[key][\"weight_decay\"], \n                }\n            print(\"Param groups = %s\" % json.dumps(to_display, indent=2))\n        \n        # state_dict = module.state_dict()\n        # for group_name in parameter_groups:\n        #     group = parameter_groups[group_name]\n        #     for name in group[\"param_names\"]:\n        #         group[\"params\"].append(state_dict[name])\n        params.extend(parameter_groups.values())\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/mmcv_custom/prepare_rpe.py",
    "content": "import torch\n\nimport numpy as np\nfrom scipy import interpolate\n\nfrom mmcv.runner import get_dist_info\nimport torch.nn as nn\n\ndef rpe_index(window_size):\n    num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n\n    # get pair-wise relative position index for each token inside the window\n    coords_h = torch.arange(window_size[0])\n    coords_w = torch.arange(window_size[1])\n    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n    relative_coords[:, :, 1] += window_size[1] - 1\n    relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n    relative_position_index = \\\n        torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\n    relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n    relative_position_index[0, 0:] = num_relative_distance - 3\n    relative_position_index[0:, 0] = num_relative_distance - 2\n    relative_position_index[0, 0] = num_relative_distance - 1\n\n    return relative_position_index\n\ndef prepare_rpe(rel_pos_bias, src_patch_shape, dst_patch_shape):\n    src_num_pos, num_attn_heads = rel_pos_bias.size()   # 732\n\n    rank, _ = get_dist_info()\n    dst_num_pos = (dst_patch_shape[0]*2 -1) * (dst_patch_shape[1]*2 -1) + 3 \n\n    if dst_patch_shape[0] != src_patch_shape[0] or dst_patch_shape[1] != src_patch_shape[1]:\n\n        num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)\n\n        # src_size = int((src_num_pos - num_extra_tokens) ** 0.5)   # 27 \n        src_size_0, src_size_1 = src_patch_shape[0] * 2 - 1, src_patch_shape[1]*2 -1\n        extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\n        rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\n\n        dst_size_0 = dst_patch_shape[0] * 2 - 1   # 42\n        dst_size_1 = dst_patch_shape[1] * 2 - 1   # 68\n\n        dim = rel_pos_bias.shape[-1]\n        rel_pos_bias = rel_pos_bias.reshape(1 , src_size_0, src_size_1, dim).permute(0, 3, 1, 2)\n        new_rel_pos_bias = nn.functional.interpolate(rel_pos_bias, scale_factor=(dst_size_0 / src_size_0, dst_size_1 / dst_size_1), mode='bicubic',) \n        new_rel_pos_bias = new_rel_pos_bias.permute(0, 2, 3, 1).view(1, -1, dim).squeeze(0)\n        new_rel_pos_bias = torch.cat((new_rel_pos_bias, extra_tokens), dim=0)\n    else:\n        new_rel_pos_bias = rel_pos_bias\n    # get rpe_index\n    relative_position_index = rpe_index(dst_patch_shape)\n    new_rel_pos_bias = new_rel_pos_bias[relative_position_index.view(-1)].view(\n                    dst_patch_shape[0] * dst_patch_shape[1] + 1,\n                    dst_patch_shape[0] * dst_patch_shape[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n    new_rel_pos_bias = new_rel_pos_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n    return new_rel_pos_bias\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/mmcv_custom/register_backbone.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport os\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint as checkpoint\n\nfrom mmcv_custom import load_checkpoint\nfrom mmdet.utils import get_root_logger\nfrom mmdet.models.builder import BACKBONES\nfrom models import VisionTransformer\nfrom .prepare_rpe import prepare_rpe\nimport time\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        self.num_patches_w = img_size[0] // patch_size\n        self.num_patches_h = img_size[1] // patch_size\n\n        num_patches = self.num_patches_w * self.num_patches_h\n        self.patch_shape = (img_size[0] // patch_size, img_size[1] // patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n            \n    def forward(self, x, mask=None):\n        B, C, H, W = x.shape\n        return self.proj(x)\n\n@BACKBONES.register_module()\nclass VisionTransformer(VisionTransformer):\n    def __init__(self,\n                 img_size,\n                 patch_size,\n                 embed_dim,\n                 in_chans=3,\n                 with_fpn=True,\n                 frozen_stages=-1,\n                 out_indices=[3, 5, 7, 11],\n                 out_with_norm=False,\n                 use_checkpoint=False,\n                 **kwargs):\n        super(VisionTransformer, self).__init__(\n            img_size=img_size,\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim, \n            **kwargs)\n        \n        # support non-square image as input\n        if len(img_size) == 1:\n            img_size = img_size * 2\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n        if self.use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        elif self.use_sincos_pos_emb:\n            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)\n        else:\n            self.pos_embed = None\n\n        \n        self.patch_size = patch_size\n        self.with_fpn = with_fpn\n        self.frozen_stages = frozen_stages\n        self.out_indices = out_indices\n        self.use_checkpoint = use_checkpoint\n\n        if not out_with_norm:\n            self.norm = nn.Identity()\n\n        if with_fpn and patch_size == 16:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n                nn.SyncBatchNorm(embed_dim),\n                nn.GELU(),\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn3 = nn.Identity()\n\n            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)\n        elif with_fpn and patch_size == 8:\n            self.fpn1 = nn.Sequential(\n                
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Identity()\n\n            self.fpn3 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=2, stride=2),\n            )\n\n            self.fpn4 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=4, stride=4),\n            )\n        else:\n            logger = get_root_logger()\n            logger.info('Build model without FPN.')\n\n\n    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):\n        h, w = self.patch_embed.patch_shape \n        grid_w = torch.arange(w, dtype=torch.float32)\n        grid_h = torch.arange(h, dtype=torch.float32)\n        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)\n        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'\n        pos_dim = embed_dim // 4\n        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n        omega = 1. / (temperature ** omega)\n        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])\n        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])\n        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]\n\n        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)\n        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))\n        pos_embed.requires_grad = False\n        return pos_embed\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n        super(VisionTransformer, self).train(mode)\n        self._freeze_stages()\n        if self.pos_embed is not None:\n            if self.pos_embed.requires_grad:\n                print(\"=================pos_embed update ================\")\n            else:\n                print(\"=================pos_embed static ================\")\n            \n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n            self.cls_token.requires_grad = False\n            if self.pos_embed is not None and self.use_sincos_pos_emb == True:\n                self.pos_embed.requires_grad = False\n            self.pos_drop.eval()\n\n        for i in range(1, self.frozen_stages + 1):\n            \n            if i  == len(self.blocks):\n                norm_layer = getattr(self, 'norm') #f'norm{i-1}')\n                norm_layer.eval()\n                for param in norm_layer.parameters():\n                    param.requires_grad = False\n\n            m = self.blocks[i - 1]\n            m.eval()\n            for param in m.parameters():\n                param.requires_grad = False\n            \n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        if isinstance(pretrained, str):\n            self.apply(self._init_weights)\n            logger = get_root_logger()\n            if  os.path.isfile(pretrained):\n                load_checkpoint(self, pretrained, strict=False, logger=logger)\n            else:\n                logger.info(f\"checkpoint path {pretrained} is invalid, we skip it and initialize net randomly\")\n        elif pretrained is None:\n            
self.apply(self._init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def interpolate_pos_encoding(self, x, w, h):\n        npatch = x.shape[1] - 1\n        N = self.pos_embed.shape[1] - 1\n        w0 = w // self.patch_embed.patch_size\n        h0 = h // self.patch_embed.patch_size\n        if npatch == N and w0 == self.patch_embed.num_patches_w and h0 == self.patch_embed.num_patches_h:\n            return self.pos_embed\n        class_pos_embed = self.pos_embed[:, 0]\n        patch_pos_embed = self.pos_embed[:, 1:]\n        dim = x.shape[-1]\n        # we add a small number to avoid floating point error in the interpolation\n        # see discussion at https://github.com/facebookresearch/dino/issues/8\n        w0, h0 = w0 + 0.1, h0 + 0.1\n        \n        tmp=patch_pos_embed.reshape(1, self.patch_embed.num_patches_w, self.patch_embed.num_patches_h, dim).permute(0, 3, 1, 2)\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, self.patch_embed.num_patches_w, self.patch_embed.num_patches_h, dim).permute(0, 3, 1, 2),\n            scale_factor=(w0 / self.patch_embed.num_patches_w, h0 / self.patch_embed.num_patches_h),\n            mode='bicubic',\n        )\n        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)\n\n    def forward(self, x):\n        B, _, H, W = x.shape\n        Hp, Wp = H // self.patch_size, W // self.patch_size\n            \n        x = self.prepare_tokens(x)\n        features = []\n        \n        time_begin = time.time()\n        if self.relative_position_bias_table is None:\n            x_rpe = None\n        else:\n            dst_rpe_shape = (Wp, Hp) if H <= W else(Hp, Wp) \n            x_rpe = prepare_rpe(self.relative_position_bias_table, self.patch_embed.patch_shape, dst_rpe_shape)\n        for i, blk in enumerate(self.blocks):\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, x_rpe)\n            else:\n                x = blk(x, x_rpe)\n            if i in self.out_indices:\n                xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp)       \n                features.append(xp.contiguous())\n        time_backbone = time.time()\n        \n        if self.with_fpn:\n            ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]\n            for i in range(len(features)):\n                features[i] = ops[i](features[i])\n        time_end = time.time()\n        return tuple(features)\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/__init__.py",
    "content": "\n# Copyright (c) Open-MMLab. All rights reserved.\nfrom .checkpoint import save_checkpoint\nfrom .epoch_based_runner import EpochBasedRunnerAmp\n\n\n__all__ = [\n    'EpochBasedRunnerAmp', 'save_checkpoint'\n]"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/checkpoint.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nCopy-paste from mmcv library:\nhttps://github.com/open-mmlab/mmcv/\n\"\"\"\n\nimport os.path as osp\nimport time\nimport torch\nimport mmcv\ntry:\n    import apex\nexcept:\n    print('apex is not installed')\n\nfrom tempfile import TemporaryDirectory\nfrom torch.optim import Optimizer\nfrom mmcv.parallel import is_module_wrapper\nfrom mmcv.runner.checkpoint import weights_to_cpu, get_state_dict\n\ndef save_checkpoint(model, filename, optimizer=None, meta=None):\n    \"\"\"Save checkpoint to file.\n    The checkpoint will have 4 fields: ``meta``, ``state_dict`` and\n    ``optimizer``, ``amp``. By default ``meta`` will contain version\n    and time info.\n    Args:\n        model (Module): Module whose params are to be saved.\n        filename (str): Checkpoint filename.\n        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.\n        meta (dict, optional): Metadata to be saved in checkpoint.\n    \"\"\"\n    if meta is None:\n        meta = {}\n    elif not isinstance(meta, dict):\n        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')\n    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())\n\n    if is_module_wrapper(model):\n        model = model.module\n\n    if hasattr(model, 'CLASSES') and model.CLASSES is not None:\n        # save class name to the meta\n        meta.update(CLASSES=model.CLASSES)\n\n    checkpoint = {\n        'meta': meta,\n        'state_dict': weights_to_cpu(get_state_dict(model))\n    }\n    # save optimizer state dict in the checkpoint\n    if isinstance(optimizer, Optimizer):\n        checkpoint['optimizer'] = optimizer.state_dict()\n    elif isinstance(optimizer, dict):\n        checkpoint['optimizer'] = {}\n        for name, optim in optimizer.items():\n            checkpoint['optimizer'][name] = optim.state_dict()\n\n    # save amp state dict in the checkpoint\n    checkpoint['amp'] = apex.amp.state_dict()\n\n    if filename.startswith('pavi://'):\n        try:\n            from pavi import modelscloud\n            from pavi.exception import NodeNotFoundError\n        except ImportError:\n            raise ImportError(\n                'Please install pavi to load checkpoint from modelcloud.')\n        model_path = filename[7:]\n        root = modelcloud.Folder()\n        model_dir, model_name = osp.split(model_path)\n        try:\n            model = modelcloud.get(model_dir)\n        except NodeNotFoundError:\n            model = root.create_training_model(model_dir)\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint_file = osp.join(tmp_dir, model_name)\n            with open(checkpoint_file, 'wb') as f:\n                torch.save(checkpoint, f)\n                f.flush()\n            model.create_file(checkpoint_file, name=model_name)\n    else:\n        mmcv.mkdir_or_exist(osp.dirname(filename))\n        # immediately flush buffer\n        with open(filename, 'wb') as f:\n            torch.save(checkpoint, f)\n            f.flush()\n"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/epoch_based_runner.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nCopy-paste from mmcv library:\nhttps://github.com/open-mmlab/mmcv/\n\"\"\"\n\nimport os.path as osp\nimport platform\nimport shutil\nimport torch\nimport mmcv\ntry:\n    import apex\nexcept:\n    print('apex is not installed')\n\nfrom torch.optim import Optimizer\nfrom mmcv.runner import RUNNERS, EpochBasedRunner\nfrom .checkpoint import save_checkpoint\n\n@RUNNERS.register_module()\nclass EpochBasedRunnerAmp(EpochBasedRunner):\n    \"\"\"Epoch-based Runner with AMP support.\n    This runner train models epoch by epoch.\n    \"\"\"\n\n    def save_checkpoint(self,\n                        out_dir,\n                        filename_tmpl='epoch_{}.pth',\n                        save_optimizer=True,\n                        meta=None,\n                        create_symlink=True):\n        \"\"\"Save the checkpoint.\n        Args:\n            out_dir (str): The directory that checkpoints are saved.\n            filename_tmpl (str, optional): The checkpoint filename template,\n                which contains a placeholder for the epoch number.\n                Defaults to 'epoch_{}.pth'.\n            save_optimizer (bool, optional): Whether to save the optimizer to\n                the checkpoint. Defaults to True.\n            meta (dict, optional): The meta information to be saved in the\n                checkpoint. Defaults to None.\n            create_symlink (bool, optional): Whether to create a symlink\n                \"latest.pth\" to point to the latest checkpoint.\n                Defaults to True.\n        \"\"\"\n        if meta is None:\n            meta = dict(epoch=self.epoch + 1, iter=self.iter)\n        elif isinstance(meta, dict):\n            meta.update(epoch=self.epoch + 1, iter=self.iter)\n        else:\n            raise TypeError(\n                f'meta should be a dict or None, but got {type(meta)}')\n        if self.meta is not None:\n            meta.update(self.meta)\n\n        filename = filename_tmpl.format(self.epoch + 1)\n        filepath = osp.join(out_dir, filename)\n        optimizer = self.optimizer if save_optimizer else None\n        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)\n        # in some environments, `os.symlink` is not supported, you may need to\n        # set `create_symlink` to False\n        if create_symlink:\n            dst_file = osp.join(out_dir, 'latest.pth')\n            if platform.system() != 'Windows':\n                mmcv.symlink(filename, dst_file)\n            else:\n                shutil.copy(filepath, dst_file)\n\n    def resume(self,\n               checkpoint,\n               resume_optimizer=True,\n               map_location='default'):\n        if map_location == 'default':\n            if torch.cuda.is_available():\n                device_id = torch.cuda.current_device()\n                checkpoint = self.load_checkpoint(\n                    checkpoint,\n                    map_location=lambda storage, loc: storage.cuda(device_id))\n            else:\n                checkpoint = self.load_checkpoint(checkpoint)\n        else:\n            checkpoint = self.load_checkpoint(\n                checkpoint, map_location=map_location)\n\n        self._epoch = checkpoint['meta']['epoch']\n        self._iter = checkpoint['meta']['iter']\n        if 'optimizer' in checkpoint and 
resume_optimizer:\n            if isinstance(self.optimizer, Optimizer):\n                self.optimizer.load_state_dict(checkpoint['optimizer'])\n            elif isinstance(self.optimizer, dict):\n                for k in self.optimizer.keys():\n                    self.optimizer[k].load_state_dict(\n                        checkpoint['optimizer'][k])\n            else:\n                raise TypeError(\n                    'Optimizer should be dict or torch.optim.Optimizer '\n                    f'but got {type(self.optimizer)}')\n\n        if 'amp' in checkpoint:\n            apex.amp.load_state_dict(checkpoint['amp'])\n            self.logger.info('load amp state dict')\n\n        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/test.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nMostly copy-paste from mmdetection library:\nhttps://github.com/open-mmlab/mmdetection/blob/master/tools/test.py\n\"\"\"\n\nimport argparse\nimport os\nimport warnings\nimport mmcv\nimport torch\n\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\nfrom mmdet.apis import multi_gpu_test, single_gpu_test\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase'\n        'the inference speed')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without perform evaluation. It is'\n        'useful when you want to format the result to a specific format and '\n        'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depends on the dataset, e.g., \"bbox\",'\n        ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where painted images will be saved')\n    parser.add_argument(\n        '--show-score-thr',\n        type=float,\n        default=0.3,\n        help='score threshold (default: 0.3)')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n        'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b '\n        'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\" '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function (deprecate), '\n        'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot be both '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n        or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results / save the results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format_only cannot be both specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file must be a pkl file.')\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n            # Replace 'ImageToTensor' to 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n        for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n    
        [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=samples_per_gpu,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this walkaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    if not distributed:\n        model = MMDataParallel(model, device_ids=[0])\n        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                  args.show_score_thr)\n    else:\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False)\n        outputs = multi_gpu_test(model, data_loader, args.tmpdir,\n                                 args.gpu_collect)\n\n    rank, _ = get_dist_info()\n    if rank == 0:\n        if args.out:\n            print(f'\\nwriting results to {args.out}')\n            mmcv.dump(outputs, args.out)\n        kwargs = {} if args.eval_options is None else args.eval_options\n        if args.format_only:\n            dataset.format_results(outputs, **kwargs)\n        if args.eval:\n            eval_kwargs = cfg.get('evaluation', {}).copy()\n            # hard-code way to remove EvalHook args\n            for key in [\n                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',\n                    'rule'\n            ]:\n                eval_kwargs.pop(key, None)\n            eval_kwargs.update(dict(metric=args.eval, **kwargs))\n            print(dataset.evaluate(outputs, **eval_kwargs))\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "downstream_tasks/detection/evaluation/object_detection/train.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nMostly copy-paste from mmdetection library:\nhttps://github.com/open-mmlab/mmdetection/blob/master/tools/train.py\n\"\"\"\n\nimport argparse\nimport copy\nimport os\nimport os.path as osp\nimport time\nimport warnings\nimport mmcv\nimport torch\n\nfrom mmcv import Config, DictAction\nfrom mmcv.runner import get_dist_info, init_dist\nfrom mmcv.utils import get_git_hash\nfrom mmdet import __version__\nfrom mmdet.apis import set_random_seed, train_detector\nfrom mmdet.datasets import build_dataset\nfrom mmdet.models import build_detector\nfrom mmdet.utils import collect_env, get_root_logger\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Train a detector')\n    parser.add_argument('config', help='train config file path')\n    parser.add_argument('--work-dir', help='the dir to save logs and models')\n    parser.add_argument(\n        '--resume-from', help='the checkpoint file to resume from')\n    parser.add_argument(\n        '--no-validate',\n        action='store_true',\n        help='whether not to evaluate the checkpoint during training')\n    group_gpus = parser.add_mutually_exclusive_group()\n    group_gpus.add_argument(\n        '--gpus',\n        type=int,\n        help='number of gpus to use '\n        '(only applicable to non-distributed training)')\n    group_gpus.add_argument(\n        '--gpu-ids',\n        type=int,\n        nargs='+',\n        help='ids of gpus to use '\n        '(only applicable to non-distributed training)')\n    parser.add_argument('--seed', type=int, default=None, help='random seed')\n    parser.add_argument(\n        '--deterministic',\n        action='store_true',\n        help='whether to set deterministic options for CUDNN backend.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file (deprecate), '\n        'change to --cfg-options instead.')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b '\n        'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\" '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.cfg_options:\n        raise ValueError(\n            '--options and --cfg-options cannot be both '\n            'specified, --options is deprecated in favor of --cfg-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --cfg-options')\n        args.cfg_options = args.options\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # import modules from string list.\n    if cfg.get('custom_imports', None):\n        from mmcv.utils import import_modules_from_strings\n        import_modules_from_strings(**cfg['custom_imports'])\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    # work_dir is determined in this priority: CLI > segment in file > filename\n    if args.work_dir is not None:\n        # update configs according to CLI args if args.work_dir is not None\n        cfg.work_dir = args.work_dir\n    elif cfg.get('work_dir', None) is None:\n        # use config filename as default work_dir if cfg.work_dir is None\n        cfg.work_dir = osp.join('./work_dirs',\n                                osp.splitext(osp.basename(args.config))[0])\n    if args.resume_from is not None:\n        cfg.resume_from = args.resume_from\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids\n    else:\n        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n        # re-set gpu_ids with distributed training mode\n        _, world_size = get_dist_info()\n        cfg.gpu_ids = range(world_size)\n\n    # create work_dir\n    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))\n    # dump config\n    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))\n    # init the logger before other steps\n    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')\n    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)\n\n    # init the meta dict to record some important information such as\n    # environment info and seed, which will be logged\n    meta = dict()\n    # log env info\n    env_info_dict = collect_env()\n    env_info = '\\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])\n    dash_line = '-' * 60 + '\\n'\n    logger.info('Environment info:\\n' + dash_line + env_info + '\\n' +\n                dash_line)\n    meta['env_info'] = env_info\n    meta['config'] = cfg.pretty_text\n    # log some basic info\n    logger.info(f'Distributed training: {distributed}')\n    logger.info(f'Config:\\n{cfg.pretty_text}')\n\n    # set random seeds\n    if args.seed is not None:\n        logger.info(f'Set random seed to {args.seed}, '\n  
                  f'deterministic: {args.deterministic}')\n        set_random_seed(args.seed, deterministic=args.deterministic)\n    cfg.seed = args.seed\n    meta['seed'] = args.seed\n    meta['exp_name'] = osp.basename(args.config)\n\n    model = build_detector(\n        cfg.model,\n        train_cfg=cfg.get('train_cfg'),\n        test_cfg=cfg.get('test_cfg'))\n\n    datasets = [build_dataset(cfg.data.train)]\n    if len(cfg.workflow) == 2:\n        val_dataset = copy.deepcopy(cfg.data.val)\n        val_dataset.pipeline = cfg.data.train.pipeline\n        datasets.append(build_dataset(val_dataset))\n    if cfg.checkpoint_config is not None:\n        # save mmdet version, config file content and class names in\n        # checkpoints as meta data\n        cfg.checkpoint_config.meta = dict(\n            mmdet_version=__version__ + get_git_hash()[:7],\n            CLASSES=datasets[0].CLASSES)\n    # add an attribute for visualization convenience\n    model.CLASSES = datasets[0].CLASSES\n    train_detector(\n        model,\n        datasets,\n        cfg,\n        distributed=distributed,\n        validate=(not args.no_validate),\n        timestamp=timestamp,\n        meta=meta)\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "downstream_tasks/detection/loader.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport random\nimport math\nimport numpy as np\n\nfrom torchvision.datasets import ImageFolder\n\nclass ImageFolderInstance(ImageFolder):\n    def __getitem__(self, index):\n        img, target = super(ImageFolderInstance, self).__getitem__(index)\n        return img, target, index\n\nclass ImageFolderMask(ImageFolder):\n    def __init__(self, *args, patch_size, pred_ratio, pred_ratio_var, pred_aspect_ratio, \n                 pred_shape='block', pred_start_epoch=0, **kwargs):\n        super(ImageFolderMask, self).__init__(*args, **kwargs)\n        self.psz = patch_size\n        self.pred_ratio = pred_ratio[0] if isinstance(pred_ratio, list) and \\\n            len(pred_ratio) == 1 else pred_ratio\n        self.pred_ratio_var = pred_ratio_var[0] if isinstance(pred_ratio_var, list) and \\\n            len(pred_ratio_var) == 1 else pred_ratio_var\n        if isinstance(self.pred_ratio, list) and not isinstance(self.pred_ratio_var, list):\n            self.pred_ratio_var = [self.pred_ratio_var] * len(self.pred_ratio)\n        self.log_aspect_ratio = tuple(map(lambda x: math.log(x), pred_aspect_ratio))\n        self.pred_shape = pred_shape\n        self.pred_start_epoch = pred_start_epoch\n\n    def get_pred_ratio(self):\n        if hasattr(self, 'epoch') and self.epoch < self.pred_start_epoch:\n            return 0\n\n        if isinstance(self.pred_ratio, list):\n            pred_ratio = []\n            for prm, prv in zip(self.pred_ratio, self.pred_ratio_var):\n                assert prm >= prv\n                pr = random.uniform(prm - prv, prm + prv) if prv > 0 else prm\n                pred_ratio.append(pr)\n            pred_ratio = random.choice(pred_ratio)\n        else:\n            assert self.pred_ratio >= self.pred_ratio_var\n            pred_ratio = random.uniform(self.pred_ratio - self.pred_ratio_var, self.pred_ratio + \\\n                self.pred_ratio_var) if self.pred_ratio_var > 0 else self.pred_ratio\n        \n        return pred_ratio\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n\n    def __getitem__(self, index):\n        output = super(ImageFolderMask, self).__getitem__(index)\n                \n        masks = []\n        for img in output[0]:\n            try:\n                H, W = img.shape[1] // self.psz, img.shape[2] // self.psz\n            except:\n                # skip non-image\n                continue\n            \n            high = self.get_pred_ratio() * H * W\n            \n            if self.pred_shape == 'block':\n                # following BEiT (https://arxiv.org/abs/2106.08254), see at\n                # https://github.com/microsoft/unilm/blob/b94ec76c36f02fb2b0bf0dcb0b8554a2185173cd/beit/masking_generator.py#L55\n                mask = np.zeros((H, W), dtype=bool)\n                mask_count = 0\n                while mask_count < high:\n                    max_mask_patches = high - mask_count\n\n                    delta = 0\n                    for attempt in range(10):\n                        low = (min(H, W) // 3) ** 2 \n                        target_area = random.uniform(low, max_mask_patches)\n                        aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))\n                        h = int(round(math.sqrt(target_area * aspect_ratio)))\n                        w = 
int(round(math.sqrt(target_area / aspect_ratio)))\n                        if w < W and h < H:\n                            top = random.randint(0, H - h)\n                            left = random.randint(0, W - w)\n\n                            num_masked = mask[top: top + h, left: left + w].sum()\n                            if 0 < h * w - num_masked <= max_mask_patches:\n                                for i in range(top, top + h):\n                                    for j in range(left, left + w):\n                                        if mask[i, j] == 0:\n                                            mask[i, j] = 1\n                                            delta += 1\n\n                        if delta > 0:\n                            break\n\n                    if delta == 0:\n                        break\n                    else:\n                        mask_count += delta\n            \n            elif self.pred_shape == 'rand':\n                mask = np.hstack([\n                    np.zeros(H * W - int(high)),\n                    np.ones(int(high)),\n                ]).astype(bool)\n                np.random.shuffle(mask)\n                mask = mask.reshape(H, W)\n\n            else:\n                # no implementation\n                assert False\n\n            masks.append(mask)\n\n        return output + (masks,)"
  },
  {
    "path": "downstream_tasks/detection/models/__init__.py",
    "content": "from .vision_transformer import VisionTransformer, vit_tiny, vit_small, vit_base, vit_large\nfrom .swin_transformer import SwinTransformer, swin_tiny, swin_small, swin_base, swin_large"
  },
  {
    "path": "downstream_tasks/detection/models/head.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nimport torch.nn as nn\nimport utils\n\nfrom utils import trunc_normal_\n\nclass CSyncBatchNorm(nn.SyncBatchNorm):\n    def __init__(self,\n                 *args,\n                 with_var=False,\n                 **kwargs):\n        super(CSyncBatchNorm, self).__init__(*args, **kwargs)\n        self.with_var = with_var\n\n    def forward(self, x):\n        # center norm\n        self.training = False\n        if not self.with_var:\n            self.running_var = torch.ones_like(self.running_var)\n        normed_x = super(CSyncBatchNorm, self).forward(x)\n        # udpate center\n        self.training = True\n        _ = super(CSyncBatchNorm, self).forward(x)\n        return normed_x\n\nclass PSyncBatchNorm(nn.SyncBatchNorm):\n    def __init__(self,\n                 *args,\n                 bunch_size,\n                 **kwargs):\n        procs_per_bunch = min(bunch_size, utils.get_world_size())\n        assert utils.get_world_size() % procs_per_bunch == 0\n        n_bunch = utils.get_world_size() // procs_per_bunch\n        #\n        ranks = list(range(utils.get_world_size()))\n        print('---ALL RANKS----\\n{}'.format(ranks))\n        rank_groups = [ranks[i*procs_per_bunch: (i+1)*procs_per_bunch] for i in range(n_bunch)]\n        print('---RANK GROUPS----\\n{}'.format(rank_groups))\n        process_groups = [torch.distributed.new_group(pids) for pids in rank_groups]\n        bunch_id = utils.get_rank() // procs_per_bunch\n        process_group = process_groups[bunch_id]\n        print('---CURRENT GROUP----\\n{}'.format(process_group))\n        super(PSyncBatchNorm, self).__init__(*args, process_group=process_group, **kwargs)\n\nclass CustomSequential(nn.Sequential):\n    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)\n\n    def forward(self, input):\n        for module in self:\n            dim = len(input.shape)\n            if isinstance(module, self.bn_types) and dim > 2:\n                perm = list(range(dim - 1)); perm.insert(1, dim - 1)\n                inv_perm = list(range(dim)) + [1]; inv_perm.pop(1)\n                input = module(input.permute(*perm)).permute(*inv_perm)\n            else:\n                input = module(input)\n        return input\n\nclass DINOHead(nn.Module):\n    def __init__(self, in_dim, out_dim, norm=None, act='gelu', last_norm=None, \n                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, **kwargs):\n        super().__init__()\n        norm = self._build_norm(norm, hidden_dim)\n        last_norm = self._build_norm(last_norm, out_dim, affine=False, **kwargs)\n        act = self._build_act(act)\n\n        nlayers = max(nlayers, 1)\n        if nlayers == 1:\n            if bottleneck_dim > 0:\n                self.mlp = nn.Linear(in_dim, bottleneck_dim)\n            else:\n                self.mlp = nn.Linear(in_dim, out_dim)\n        else:\n            layers = [nn.Linear(in_dim, hidden_dim)]\n            if norm is not None:\n                layers.append(norm)\n            layers.append(act)\n            for _ in range(nlayers - 2):\n                layers.append(nn.Linear(hidden_dim, hidden_dim))\n                if norm is not None:\n                    layers.append(norm)\n                layers.append(act)\n            if bottleneck_dim > 0:\n        
        layers.append(nn.Linear(hidden_dim, bottleneck_dim))\n            else:\n                layers.append(nn.Linear(hidden_dim, out_dim))\n            self.mlp = CustomSequential(*layers)\n        self.apply(self._init_weights)\n        \n        if bottleneck_dim > 0:\n            self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))\n            self.last_layer.weight_g.data.fill_(1)\n            if norm_last_layer:\n                self.last_layer.weight_g.requires_grad = False\n        else:\n            self.last_layer = None\n\n        self.last_norm = last_norm\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        x = self.mlp(x)\n        if self.last_layer is not None:\n            x = nn.functional.normalize(x, dim=-1, p=2)\n            x = self.last_layer(x)\n        if self.last_norm is not None:\n            x = self.last_norm(x)\n        return x\n\n    def _build_norm(self, norm, hidden_dim, **kwargs):\n        if norm == 'bn':\n            norm = nn.BatchNorm1d(hidden_dim, **kwargs)\n        elif norm == 'syncbn':\n            norm = nn.SyncBatchNorm(hidden_dim, **kwargs)\n        elif norm == 'csyncbn':\n            norm = CSyncBatchNorm(hidden_dim, **kwargs)\n        elif norm == 'psyncbn':\n            norm =  PSyncBatchNorm(hidden_dim, **kwargs)\n        elif norm == 'ln':\n            norm = nn.LayerNorm(hidden_dim, **kwargs)\n        else:\n            assert norm is None, \"unknown norm type {}\".format(norm)\n        return norm\n\n    def _build_act(self, act):\n        if act == 'relu':\n            act = nn.ReLU()\n        elif act == 'gelu':\n            act = nn.GELU()\n        else:\n            assert False, \"unknown act type {}\".format(act)\n        return act\n\nclass iBOTHead(DINOHead):\n\n    def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None, \n                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, \n                 shared_head=False, **kwargs):\n        \n        super(iBOTHead, self).__init__(*args,\n                                        norm=norm,\n                                        act=act,\n                                        last_norm=last_norm,\n                                        nlayers=nlayers,\n                                        hidden_dim=hidden_dim,\n                                        bottleneck_dim=bottleneck_dim,\n                                        norm_last_layer=norm_last_layer, \n                                        **kwargs)\n\n        if not shared_head:\n            if bottleneck_dim > 0:\n                self.last_layer2 = nn.utils.weight_norm(nn.Linear(bottleneck_dim, patch_out_dim, bias=False))\n                self.last_layer2.weight_g.data.fill_(1)\n                if norm_last_layer:\n                    self.last_layer2.weight_g.requires_grad = False\n            else:\n                self.mlp2 = nn.Linear(hidden_dim, patch_out_dim)\n                self.last_layer2 = None\n\n            self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs)\n        else:\n            if bottleneck_dim > 0:\n                self.last_layer2 = self.last_layer\n            else:\n                self.mlp2 = self.mlp[-1]\n                self.last_layer2 = None\n\n      
      self.last_norm2 = self.last_norm\n\n    def forward(self, x):\n        if len(x.shape) == 2:\n            return super(iBOTHead, self).forward(x)\n\n        if self.last_layer is not None:\n            x = self.mlp(x)\n            x = nn.functional.normalize(x, dim=-1, p=2)\n            x1 = self.last_layer(x[:, 0])\n            x2 = self.last_layer2(x[:, 1:])\n        else:\n            x = self.mlp[:-1](x)\n            x1 = self.mlp[-1](x[:, 0])\n            x2 = self.mlp2(x[:, 1:])\n        \n        if self.last_norm is not None:\n            x1 = self.last_norm(x1)\n            x2 = self.last_norm2(x2)\n        \n        return x1, x2\n"
  },
  {
    "path": "downstream_tasks/detection/models/swin_transformer.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nMostly copy-paste from Swin-Transformer libarary:\nhttps://github.com/facebookresearch/dino\nhttps://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py\n\"\"\"\n\nimport os\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.distributed as dist\n\nfrom math import sqrt\nfrom functools import partial\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super(Mlp, self).__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\"Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super(WindowAttention, self).__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2 Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn_out = attn\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x, attn_out\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = 
self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n    @staticmethod\n    def compute_macs(module, input, output):\n        B, N, C = input[0].shape\n\n        module.__flops__ += module.flops(N) * B\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\"Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resulotion.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        self.H = input_resolution[0]\n        self.W = input_resolution[1]\n\n        self.attn_mask_dict = {} # {self.H: self.create_attn_mask(self.H, self.W)}\n\n\n    def create_attn_mask(self, H, W):\n        # calculate attention mask for SW-MSA\n\n        Hp = int(np.ceil(H / self.window_size)) * self.window_size\n        Wp = int(np.ceil(W / self.window_size)) * self.window_size\n        img_mask = torch.zeros((1, Hp, Wp, 1))  # 1 Hp Wp 1\n        h_slices = (slice(0, -self.window_size),\n                    slice(-self.window_size, -self.shift_size),\n                    slice(-self.shift_size, None))\n        w_slices = (slice(0, -self.window_size),\n                    slice(-self.window_size, -self.shift_size),\n                    slice(-self.shift_size, None))\n        cnt = 0\n        for h in h_slices:\n            for w in w_slices:\n                img_mask[:, h, w, :] = cnt\n                cnt += 1\n\n        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n\n        return attn_mask\n\n\n    def forward(self, x):\n\n        B, L, C = x.shape\n        H = int(sqrt(L))\n        W = H\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # pad feature maps to multiples of window size\n        pad_l = pad_t = 0\n        pad_r = (self.window_size - W % self.window_size) % self.window_size\n        pad_b = (self.window_size - H % self.window_size) % self.window_size\n        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))\n        _, Hp, Wp, _ = x.shape\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n\n            if H in self.attn_mask_dict.keys():\n                attn_mask = self.attn_mask_dict[H]\n            else:\n                self.attn_mask_dict[H] = self.create_attn_mask(H, W).to(x.device)\n                attn_mask = self.attn_mask_dict[H]\n\n        else:\n            shifted_x = x\n            attn_mask = None\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows, attn = self.attn(x_windows, attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n\n        if pad_r > 0 or pad_b > 0:\n            x = x[:, :H, :W, :].contiguous()\n\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + 
self.drop_path(self.mlp(self.norm2(x)))\n\n        return x, attn\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size} mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\"Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\" Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        H = int(sqrt(L))\n        W = H\n\n        x = x.view(B, H, W, C)\n\n        # padding\n        pad_input = (H % 2 == 1) or (W % 2 == 1)\n        if pad_input:\n            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\"A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resulotion.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer)\n            for i in range(depth)])\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            x, _ = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def forward_with_features(self, x):\n        fea = []\n        for blk in self.blocks:\n            x, _ = blk(x)\n            fea.append(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x, fea\n\n    def forward_with_attention(self, x):\n        attns = []\n        for blk in self.blocks:\n            x, attn = blk(x)\n            attns.append(attn)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x, attns\n\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n\n        # # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #     f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n\n        x = self.proj(x)\n        B, C, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)  # B 
Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x.transpose(1, 2).reshape(B, C, H, W)\n\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinTransformer(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n\n    Args:\n        img_size (int | tuple(int)): Input image size.\n        patch_size (int | tuple(int)): Patch size.\n        in_chans (int): Number of input channels.\n        num_classes (int): Number of classes for classification head.\n        embed_dim (int): Embedding dimension.\n        depths (tuple(int)): Depth of Swin Transformer layers.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.\n        drop_rate (float): Dropout rate.\n        attn_drop_rate (float): Attention dropout rate.\n        drop_path_rate (float): Stochastic depth rate.\n        norm_layer (nn.Module): normalization layer.\n        ape (bool): If True, add absolute position embedding to the patch embedding.\n        patch_norm (bool): If True, add normalization after patch embedding.\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,\n                 norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True, \n                 return_all_tokens=False, use_mean_pooling=True, masked_im_modeling=False):\n\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.depths = depths\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n        self.return_all_tokens = return_all_tokens\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               
input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n        # masked image modeling\n        self.masked_im_modeling = masked_im_modeling\n        if masked_im_modeling:\n            self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim))\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        # todo: to be implemented\n        return {'relative_position_bias_table'}\n\n    def forward(self, x, return_all_tokens=None, mask=None):\n        # patch linear embedding\n        x = self.patch_embed(x)\n        # mask image modeling\n        if mask is not None:\n            x = self.mask_model(x, mask)\n        x = x.flatten(2).transpose(1, 2)\n\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x_region = self.norm(x)  # B L C\n        x = self.avgpool(x_region.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n\n        return_all_tokens = self.return_all_tokens if \\\n            return_all_tokens is None else return_all_tokens\n        if return_all_tokens:\n            return torch.cat([x.unsqueeze(1), x_region], dim=1)\n        return x\n\n    def get_selfattention(self, x, n=1):\n        # n=1 return the last layer attn map; otherwise return attn maps in all layers\n        x = self.patch_embed(x)\n        x = x.flatten(2).transpose(1, 2)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        if n==1:\n            return self.get_last_selfattention(x)\n        else:\n            return self.get_all_selfattention(x)\n\n    def get_last_selfattention(self, x):\n\n        for i, layer in enumerate(self.layers):\n            if i < len(self.layers) - 1:\n                x = layer(x)\n            else:\n                x, attns = layer.forward_with_attention(x)\n                return attns[-1]\n\n    def get_all_selfattention(self, x):\n        attn_out = []\n\n        for layer in self.layers:\n            x, attns = 
layer.forward_with_attention(x)\n            attn_out += attns\n\n        return attn_out\n\n    def get_intermediate_layers(self, x, n=1, return_patch_avgpool=False):\n\n        num_blks = sum(self.depths)\n        start_idx = num_blks - n\n\n        sum_cur = 0\n        for i, d in enumerate(self.depths):\n            sum_cur_new = sum_cur + d\n            if start_idx >= sum_cur and start_idx < sum_cur_new:\n                start_stage = i\n                start_blk = start_idx - sum_cur\n            sum_cur = sum_cur_new\n\n\n        x = self.patch_embed(x)\n        x = x.flatten(2).transpose(1, 2)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        # we will return the averaged token features from the `n` last blocks\n        # note: there is no [CLS] token in Swin Transformer\n        output = []\n        s = 0\n        for i, layer in enumerate(self.layers):\n            x, fea = layer.forward_with_features(x)\n\n            if i >= start_stage:\n                for x_ in fea[start_blk:]:\n\n                    if i == len(self.layers)-1: # use the norm in the last stage\n                        x_ = self.norm(x_)\n\n                    x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1)  # B C \n                    if return_patch_avgpool:\n                        x_o = x_avg\n                    else:\n                        x_o = torch.cat((x_avg.unsqueeze(1), x_), dim=1)\n                    # print(f'Stage {i},  x_o {x_o.shape}')          \n                    output.append(x_o)\n\n                start_blk = 0\n\n        #return torch.cat(output, dim=-1)\n        return output\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n            if dist.get_rank() == 0:\n                print(f\"GFLOPs layer_{i}: {layer.flops() / 1e9}\")\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n\n    def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):\n        if os.path.isfile(pretrained):\n            pretrained_dict = torch.load(pretrained, map_location='cpu')\n            logging.info(f'=> loading pretrained model {pretrained}')\n            model_dict = self.state_dict()\n            pretrained_dict = {\n                k: v for k, v in pretrained_dict.items()\n                if k in model_dict.keys()\n            }\n            need_init_state_dict = {}\n            for k, v in pretrained_dict.items():\n                need_init = (\n                        k.split('.')[0] in pretrained_layers\n                        or pretrained_layers[0] is '*'\n                        or 'relative_position_index' not in k\n                        or 'attn_mask' not in k\n                )\n\n                if need_init:\n                    if verbose:\n                        logging.info(f'=> init {k} from {pretrained}')\n\n                    if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():\n                        relative_position_bias_table_pretrained = v\n                        relative_position_bias_table_current = model_dict[k]\n                        L1, nH1 = relative_position_bias_table_pretrained.size()\n                        L2, nH2 = relative_position_bias_table_current.size()\n                       
 if nH1 != nH2:\n                            logging.info(f\"Error in loading {k}, passing\")\n                        else:\n                            if L1 != L2:\n                                logging.info(\n                                    '=> load_pretrained: resized variant: {} to {}'\n                                        .format((L1, nH1), (L2, nH2))\n                                )\n                                S1 = int(L1 ** 0.5)\n                                S2 = int(L2 ** 0.5)\n                                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(\n                                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),\n                                    size=(S2, S2),\n                                    mode='bicubic')\n                                v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)\n\n                    if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():\n                        absolute_pos_embed_pretrained = v\n                        absolute_pos_embed_current = model_dict[k]\n                        _, L1, C1 = absolute_pos_embed_pretrained.size()\n                        _, L2, C2 = absolute_pos_embed_current.size()\n                        if C1 != C2:\n                            logging.info(f\"Error in loading {k}, passing\")\n                        else:\n                            if L1 != L2:\n                                logging.info(\n                                    '=> load_pretrained: resized variant: {} to {}'\n                                        .format((1, L1, C1), (1, L2, C2))\n                                )\n                                S1 = int(L1 ** 0.5)\n                                S2 = int(L2 ** 0.5)\n                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)\n                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)\n                                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(\n                                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')\n                                v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)\n\n                    need_init_state_dict[k] = v\n            self.load_state_dict(need_init_state_dict, strict=False)\n\n    def freeze_pretrained_layers(self, frozen_layers=[]):\n        for name, module in self.named_modules():\n            if (\n                    name.split('.')[0] in frozen_layers\n                    or '.'.join(name.split('.')[0:2]) in frozen_layers\n                    or (len(frozen_layers) > 0 and frozen_layers[0] == '*')\n            ):\n                for _name, param in module.named_parameters():\n                    param.requires_grad = False\n                logging.info(\n                    '=> set param {} requires grad to False'\n                        .format(name)\n                )\n        for name, param in self.named_parameters():\n            if (\n                    name.split('.')[0] in frozen_layers\n                    or (len(frozen_layers) > 0 and frozen_layers[0] == '*')\n                    and param.requires_grad is True\n            ):\n                param.requires_grad = False\n                logging.info(\n                    '=> set param {} requires grad to 
False'\n                        .format(name)\n                )\n        return self\n\n    def get_num_layers(self):\n        #return len(self.layers)\n        return sum(self.depths)\n\n    def mask_model(self, x, mask):\n        # extend mask for hierarchical features\n        if x.shape[-2:] != mask.shape[-2:]:\n            htimes, wtimes = np.array(x.shape[-2:]) // np.array(mask.shape[-2:])\n            mask = mask.repeat_interleave(htimes, -2).repeat_interleave(wtimes, -1)\n        \n        # mask embed\n        x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype)\n\n        return x\n\n@register_model\ndef swin_tiny(window_size=7, **kwargs):\n    model = SwinTransformer(\n        window_size=window_size, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n        mlp_ratio=4, qkv_bias=True, drop_path_rate=kwargs.pop('drop_path_rate', 0.1), **kwargs)\n    return model\n\n@register_model\ndef swin_small(window_size=7, **kwargs):\n    model = SwinTransformer(\n        window_size=window_size, embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],\n        mlp_ratio=4, qkv_bias=True, drop_path_rate=kwargs.pop('drop_path_rate', 0.2), **kwargs)\n    return model\n\n@register_model\ndef swin_base(window_size=7, **kwargs):\n    model = SwinTransformer(\n        window_size=window_size, embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],\n        mlp_ratio=4, qkv_bias=True, drop_path_rate=kwargs.pop('drop_path_rate', 0.2), **kwargs)\n    return model\n\n@register_model\ndef swin_large(window_size=7, **kwargs):\n    model = SwinTransformer(\n        window_size=window_size, embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48],\n        mlp_ratio=4, qkv_bias=True, drop_path_rate=kwargs.pop('drop_path_rate', 0.2), **kwargs)\n    return model"
  },
  {
    "path": "downstream_tasks/detection/models/vision_transformer.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nMostly copy-paste from DINO and timm library:\nhttps://github.com/facebookresearch/dino\nhttps://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n\"\"\"\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom functools import partial\nfrom utils import trunc_normal_\nfrom timm.models.registry import register_model\n\ndef drop_path(x, drop_prob: float = 0., training: bool = False):\n    if drop_prob == 0. or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        #self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        ####add by wxd\n        self.qkv = nn.Linear(dim, dim * 3, bias=False)\n        all_head_dim = head_dim * self.num_heads\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n        \n\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, x_rel_pos_bias = None):\n        B, N, C = x.shape\n        #qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        ########add by wxd\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q @ 
k.transpose(-2, -1)) * self.scale\n\n#        if self.relative_position_bias_table is not None:\n#            relative_position_bias = \\\n#                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n#                    self.window_size[0] * self.window_size[1] + 1,\n#                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n#            print(\"################before relative:\", relative_position_bias.shape)\n#            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n#            print(\"################after relative:\", relative_position_bias.shape)\n#            relative_position_bias = intepolate_rpe(relative_position_bias)\n#            print(\"################ater inter relative:\", relative_position_bias.shape)\n        if x_rel_pos_bias is not None:\n            attn = attn + x_rel_pos_bias.unsqueeze(0)\n\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x, attn\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., \n                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, init_values=0):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values > 0:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n\n\n    def forward(self, x, x_rel_pos_bias=None, return_attention=False):\n        y, attn = self.attn(self.norm1(x), x_rel_pos_bias)\n        if return_attention:\n            return attn\n        if self.gamma_1 is None:\n            x = x + self.drop_path(y)\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * y)\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n#class PatchEmbed(nn.Module):\n#    \"\"\" Image to Patch Embedding\n#    \"\"\"\n#    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n#        super().__init__()\n#        num_patches = (img_size // patch_size) * (img_size // patch_size)\n#        self.img_size = img_size\n#        self.patch_size = patch_size\n#        self.num_patches = num_patches\n#        print(\"#################patch in!!!\")\n#        self.patch_shape = (img_size // patch_size, img_size // patch_size)\n#\n#        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n#            \n#    def forward(self, x):\n#        B, C, H, W = x.shape\n#        return self.proj(x)\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=[224, 224], 
patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        self.num_patches_w = img_size[0] // patch_size\n        self.num_patches_h = img_size[1] // patch_size\n\n        num_patches = self.num_patches_w * self.num_patches_h\n        self.patch_shape = (img_size[0] // patch_size, img_size[1] // patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n        print(\"##############patch here!!!\")\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n            \n    def forward(self, x, mask=None):\n        B, C, H, W = x.shape\n        return self.proj(x)\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer \"\"\"\n    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), return_all_tokens=False, \n                 init_values=0, use_sincos_pos_emb=False, use_abs_pos_emb=False, use_rel_pos_bias=False, use_mean_pooling=False, masked_im_modeling=False):\n        super().__init__()\n        self.num_features = self.embed_dim = embed_dim\n        self.return_all_tokens = return_all_tokens\n        print(\"############use_abs_pos:\", use_abs_pos_emb)\n        print(\"############use_sincos_pos:\", use_sincos_pos_emb)\n        print(\"############use_rel_pos_bias:\", use_rel_pos_bias)\n        self.use_abs_pos_emb = use_abs_pos_emb\n        self.use_sincos_pos_emb = use_sincos_pos_emb\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        else:\n            self.pos_embed = None\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        self.use_rel_pos_bias = use_rel_pos_bias\n\n        if self.use_rel_pos_bias:\n            print(\"=================use RelativePositionBias===================\")\n            window_size=self.patch_embed.patch_shape\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, \n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n\n        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)\n        self.fc_norm = 
norm_layer(embed_dim) if use_mean_pooling else None\n        # Classifier head\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        \n        if use_abs_pos_emb:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n\n        # masked image modeling\n        self.masked_im_modeling = masked_im_modeling\n        if masked_im_modeling:\n            self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim))\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def interpolate_pos_encoding(self, x, w, h):\n        npatch = x.shape[1] - 1\n        print(\"############self.pos_embed:\", self.pos_embed)\n        N = self.pos_embed.shape[1] - 1\n        if npatch == N and w == h:\n            return self.pos_embed\n        class_pos_embed = self.pos_embed[:, 0]\n        patch_pos_embed = self.pos_embed[:, 1:]\n        dim = x.shape[-1]\n        w0 = w // self.patch_embed.patch_size\n        h0 = h // self.patch_embed.patch_size\n        # we add a small number to avoid floating point error in the interpolation\n        # see discussion at https://github.com/facebookresearch/dino/issues/8\n        w0, h0 = w0 + 0.1, h0 + 0.1\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),\n            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),\n            mode='bicubic',\n        )\n        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)\n\n    def prepare_tokens(self, x, mask=None):\n        B, nc, w, h = x.shape\n        # patch linear embedding\n        x = self.patch_embed(x)\n\n        # mask image modeling\n        if mask is not None:\n            x = self.mask_model(x, mask)\n        x = x.flatten(2).transpose(1, 2)\n\n        # add the [CLS] token to the embed patch tokens\n        cls_tokens = self.cls_token.expand(B, -1, -1)\n        x = torch.cat((cls_tokens, x), dim=1)\n        \n        # add positional encoding to each token\n        if self.pos_embed is not None:\n            x = x + self.interpolate_pos_encoding(x, w, h)\n\n        return self.pos_drop(x)\n\n    def forward(self, x, return_all_tokens=None, mask=None):\n        # mim\n        if self.masked_im_modeling:\n            assert mask is not None\n            x = self.prepare_tokens(x, mask=mask)\n        else:\n            x = self.prepare_tokens(x)\n\n        for blk in self.blocks:\n            x = blk(x)\n\n        x = self.norm(x)\n        if self.fc_norm is not None:\n            x[:, 0] = self.fc_norm(x[:, 1:, :].mean(1))\n        \n        return_all_tokens = self.return_all_tokens if \\\n            return_all_tokens is None else return_all_tokens\n        if return_all_tokens:\n            return x\n        return x[:, 0]\n\n    def get_last_selfattention(self, x):\n        x = self.prepare_tokens(x)\n        for i, blk in enumerate(self.blocks):\n            if i < 
len(self.blocks) - 1:\n                x = blk(x)\n            else:\n                # return attention of the last block\n                return blk(x, return_attention=True)\n\n    def get_intermediate_layers(self, x, n=1):\n        x = self.prepare_tokens(x)\n        # we return the output tokens from the `n` last blocks\n        output = []\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n            if len(self.blocks) - i <= n:\n                output.append(self.norm(x))\n        return output\n        \n    def get_num_layers(self):\n        return len(self.blocks)\n\n    def mask_model(self, x, mask):\n        x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype)\n        return x\n\ndef vit_tiny(patch_size=16, **kwargs):\n    model = VisionTransformer(\n        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,\n        qkv_bias=True, **kwargs)\n    return model\n\ndef vit_small(patch_size=16, **kwargs):\n    model = VisionTransformer(\n        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,\n        qkv_bias=True, **kwargs)\n    return model\n\ndef vit_base(patch_size=16, **kwargs):\n    model = VisionTransformer(\n        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,\n        qkv_bias=True, **kwargs)\n    return model\n\ndef vit_large(patch_size=16, **kwargs):\n    model = VisionTransformer(\n        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,\n        qkv_bias=True, **kwargs)\n    return model\n"
  },
  {
    "path": "downstream_tasks/detection/scripts/run_eval.sh",
    "content": "#!/usr/bin/env bash\n\necho \"EVAL MODEL:\"$MODEL\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    evaluation/object_detection/test.py \\\n    $CONFIG \\\n    $MODEL \\\n    --launcher pytorch \\\n    --eval bbox segm \\\n    --cfg-options model.backbone.use_checkpoint=True \\\n    ${@:6}\n\n"
  },
  {
    "path": "downstream_tasks/detection/scripts/run_train_maskrcnn_vit_base.sh",
    "content": "#!/usr/bin/env bash\n\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=$PORT \\\n    evaluation/object_detection/train.py \\\n    evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py \\\n    --launcher pytorch \\\n    --work-dir $OUTPUT_DIR \\\n    --no-validate \\\n    --deterministic \\\n    --cfg-options model.backbone.use_checkpoint=True \\\n    model.pretrained=$PRETRAINED \\\n    ${@:6}\n\n"
  },
  {
    "path": "downstream_tasks/detection/scripts/run_train_maskrcnn_vit_large.sh",
    "content": "#!/usr/bin/env bash\n\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=$PORT \\\n    evaluation/object_detection/train.py \\\n    evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py \\\n    --launcher pytorch \\\n    --work-dir $OUTPUT_DIR \\\n    --no-validate \\\n    --deterministic \\\n    --cfg-options model.backbone.use_checkpoint=True \\\n\tmodel.pretrained=$PRETRAINED \\\n    ${@:6}\n\n"
  },
  {
    "path": "downstream_tasks/detection/utils.py",
    "content": "# Copyright (c) ByteDance, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nMostly copy-paste from torchvision references or other public repos like DETR:\nhttps://github.com/facebookresearch/detr/blob/master/util/misc.py\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport math\nimport json\nimport random\nimport datetime\nimport subprocess\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom collections import defaultdict, deque\nfrom pathlib import Path\nfrom torch import nn\nfrom PIL import ImageFilter, ImageOps, Image, ImageDraw\n\nclass GaussianBlur(object):\n    \"\"\"\n    Apply Gaussian Blur to the PIL image.\n    \"\"\"\n    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):\n        self.prob = p\n        self.radius_min = radius_min\n        self.radius_max = radius_max\n\n    def __call__(self, img):\n        do_it = random.random() <= self.prob\n        if not do_it:\n            return img\n\n        return img.filter(\n            ImageFilter.GaussianBlur(\n                radius=random.uniform(self.radius_min, self.radius_max)\n            )\n        )\n\n\nclass Solarization(object):\n    \"\"\"\n    Apply Solarization to the PIL image.\n    \"\"\"\n    def __init__(self, p):\n        self.p = p\n\n    def __call__(self, img):\n        if random.random() < self.p:\n            return ImageOps.solarize(img)\n        else:\n            return img\n\n\nclass PermutePatch(object):\n    \"\"\"\n    Apply Patch permutation to the PIL image.\n    \"\"\"\n    def __init__(self, psz):\n        self.psz = psz\n\n    def __call__(self, img):\n        imgs = []\n        imgwidth, imgheight = img.size\n        for i in range(0, imgheight, self.psz):\n            for j in range(0, imgwidth, self.psz):\n                box = (j, i, j+self.psz, i+self.psz)\n                imgs.append(img.crop(box))\n        random.shuffle(imgs)\n        new_img = Image.new('RGB', (imgwidth, imgheight))\n        k = 0\n        for i in range(0, imgheight, self.psz):\n            for j in range(0, imgwidth, self.psz):\n                new_img.paste(imgs[k], (j, i))\n                k += 1\n        return new_img\n\nclass HideAndSeek(object):\n    \"\"\"\n    Apply Patch permutation to the PIL image.\n    \"\"\"\n    def __init__(self, ratio, psz):\n        self.ratio = ratio\n        self.psz = psz\n\n    def __call__(self, img):\n        imgwidth, imgheight = img.size \n        numw, numh = imgwidth // self.psz, imgheight // self.psz\n        mask_num = int(numw * numh * self.ratio)\n        mask_patch = np.random.choice(np.arange(numw * numh), mask_num, replace=False)\n        mask_w, mask_h = mask_patch % numh, mask_patch // numh\n        # img.save('test1.png')\n        draw = ImageDraw.Draw(img)\n        for mw, mh in zip(mask_w, mask_h):\n            draw.rectangle((mw * self.psz, \n                            mh * self.psz,\n                            (mw + 1) * self.psz,\n                            (mh + 1) * self.psz), fill=\"black\")\n        # img.save('test2.png')\n        return img\n\ndef load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):\n    if os.path.isfile(pretrained_weights):\n        state_dict = torch.load(pretrained_weights, map_location=\"cpu\")\n        if checkpoint_key is not None and checkpoint_key in state_dict:\n            print(f\"Take key 
{checkpoint_key} in provided checkpoint dict\")\n            state_dict = state_dict[checkpoint_key]\n        # remove `module.` prefix\n        state_dict = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n        # remove `backbone.` prefix induced by multicrop wrapper\n        state_dict = {k.replace(\"backbone.\", \"\"): v for k, v in state_dict.items()}\n        msg = model.load_state_dict(state_dict, strict=False)\n        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))\n        return\n    elif pretrained_weights == 'download':\n        url = None\n        if model_name == \"vit_small\" and patch_size == 16:\n            url = \"dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth\"\n        elif model_name == \"vit_small\" and patch_size == 8:\n            url = \"dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth\"\n        elif model_name == \"vit_base\" and patch_size == 16:\n            url = \"dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\"\n        elif model_name == \"vit_base\" and patch_size == 8:\n            url = \"dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth\"\n        if url is not None:\n            print(\"Since no pretrained weights are provided, we load the pretrained weights from {}.\".format(url))\n            state_dict = torch.hub.load_state_dict_from_url(url=\"https://dl.fbaipublicfiles.com/dino/\" + url)\n            model.load_state_dict(state_dict, strict=True)\n            return\n    elif pretrained_weights == 'supervised':\n        url = None\n        if model_name == \"vit_small\" and patch_size == 16:\n            url = \"deit_small_patch16_224-cd65a155.pth\"\n        elif model_name == \"vit_base\" and patch_size == 16:\n            url = \"deit_base_patch16_224-b5f2ef4d.pth\"\n        if url is not None:\n            print(\"Since no pretrained weights are provided, we load the pretrained weights from {}.\".format(url))\n            state_dict = torch.hub.load_state_dict_from_url(url=\"https://dl.fbaipublicfiles.com/deit/\" + url)\n            msg = model.load_state_dict(state_dict['model'], strict=False)\n            print('Supervised weights found at {} and loaded with msg: {}'.format(url, msg))\n            return\n    print(\"There is no reference weights available for this model => We use random weights.\")\n\n\ndef clip_gradients(model, clip):\n    norms = []\n    for name, p in model.named_parameters():\n        if p.grad is not None:\n            param_norm = p.grad.data.norm(2)\n            norms.append(param_norm.item())\n            clip_coef = clip / (param_norm + 1e-6)\n            if clip_coef < 1:\n                p.grad.data.mul_(clip_coef)\n    return norms\n\n\ndef cancel_gradients_last_layer(epoch, model, freeze_last_layer):\n    if epoch >= freeze_last_layer:\n        return\n    for n, p in model.named_parameters():\n        if \"last_layer\" in n:\n            p.grad = None\n\n\ndef restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):\n    \"\"\"\n    Re-start from checkpoint\n    \"\"\"\n    if not os.path.isfile(ckp_path):\n        return\n    print(\"Found checkpoint at {}\".format(ckp_path))\n\n    # open checkpoint file\n    checkpoint = torch.load(ckp_path, map_location=\"cpu\")\n\n    # key is what to look for in the checkpoint file\n    # value is the object to load\n    # example: {'state_dict': model}\n    for key, value in kwargs.items():\n        if key in checkpoint and value is not None:\n            try:\n                
msg = value.load_state_dict(checkpoint[key], strict=False)\n                print(\"=> loaded '{}' from checkpoint '{}' with msg {}\".format(key, ckp_path, msg))\n            except TypeError:\n                try:\n                    msg = value.load_state_dict(checkpoint[key])\n                    print(\"=> loaded '{}' from checkpoint: '{}'\".format(key, ckp_path))\n                except ValueError:\n                    print(\"=> failed to load '{}' from checkpoint: '{}'\".format(key, ckp_path))\n        else:\n            print(\"=> key '{}' not found in checkpoint: '{}'\".format(key, ckp_path))\n\n    # re load variable important for the run\n    if run_variables is not None:\n        for var_name in run_variables:\n            if var_name in checkpoint:\n                run_variables[var_name] = checkpoint[var_name]\n\n\ndef cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):\n    warmup_schedule = np.array([])\n    warmup_iters = warmup_epochs * niter_per_ep\n    if warmup_epochs > 0:\n        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)\n\n    iters = np.arange(epochs * niter_per_ep - warmup_iters)\n    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))\n\n    schedule = np.concatenate((warmup_schedule, schedule))\n    assert len(schedule) == epochs * niter_per_ep\n    return schedule\n\n\ndef bool_flag(s):\n    \"\"\"\n    Parse boolean arguments from the command line.\n    \"\"\"\n    FALSY_STRINGS = {\"off\", \"false\", \"0\"}\n    TRUTHY_STRINGS = {\"on\", \"true\", \"1\"}\n    if s.lower() in FALSY_STRINGS:\n        return False\n    elif s.lower() in TRUTHY_STRINGS:\n        return True\n    else:\n        raise argparse.ArgumentTypeError(\"invalid value for a boolean flag\")\n\n\ndef fix_random_seeds(seed=31):\n    \"\"\"\n    Fix random seeds.\n    \"\"\"\n    random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.6f} ({global_avg:.6f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def value(self):\n        return self.deque[-1]\n\n  
  def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value)\n\n\ndef reduce_dict(input_dict, average=True):\n    \"\"\"\n    Args:\n        input_dict (dict): all the values will be reduced\n        average (bool): whether to do average or sum\n    Reduce the values in the dictionary from all processes so that all processes\n    have the averaged results. Returns a dict with the same fields as\n    input_dict, after reduction.\n    \"\"\"\n    world_size = get_world_size()\n    if world_size < 2:\n        return input_dict\n    with torch.no_grad():\n        names = []\n        values = []\n        # sort the keys so that they are consistent across processes\n        for k in sorted(input_dict.keys()):\n            names.append(k)\n            values.append(input_dict[k])\n        values = torch.stack(values, dim=0)\n        dist.all_reduce(values)\n        if average:\n            values /= world_size\n        reduced_dict = {k: v for k, v in zip(names, values)}\n    return reduced_dict\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, attr))\n\n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\n                \"{}: {}\".format(name, str(meter))\n            )\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 0\n        if not header:\n            header = ''\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt='{avg:.6f}')\n        data_time = SmoothedValue(fmt='{avg:.6f}')\n        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'\n        if torch.cuda.is_available():\n            log_msg = self.delimiter.join([\n                header,\n                '[{0' + space_fmt + '}/{1}]',\n                'eta: {eta}',\n                '{meters}',\n                'time: {time}',\n                'data: {data}',\n                'max mem: {memory:.0f}'\n            ])\n        else:\n            log_msg = self.delimiter.join([\n                header,\n                '[{0' + space_fmt + '}/{1}]',\n                'eta: {eta}',\n                '{meters}',\n                'time: {time}',\n                'data: {data}'\n            ])\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = 
iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time),\n                        memory=torch.cuda.max_memory_allocated() / MB))\n                else:\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time)))\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print('{} Total time: {} ({:.6f} s / it)'.format(\n            header, total_time_str, total_time / len(iterable)))\n\n\ndef get_sha():\n    cwd = os.path.dirname(os.path.abspath(__file__))\n\n    def _run(command):\n        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()\n    sha = 'N/A'\n    diff = \"clean\"\n    branch = 'N/A'\n    try:\n        sha = _run(['git', 'rev-parse', 'HEAD'])\n        subprocess.check_output(['git', 'diff'], cwd=cwd)\n        diff = _run(['git', 'diff-index', 'HEAD'])\n        diff = \"has uncommited changes\" if diff else \"clean\"\n        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])\n    except Exception:\n        pass\n    message = f\"sha: {sha}, status: {diff}, branch: {branch}\"\n    return message\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef save_on_master(*args, **kwargs):\n    if is_main_process():\n        torch.save(*args, **kwargs)\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    import builtins as __builtin__\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        if is_master or force:\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\n\ndef init_distributed_mode(args):\n    # launched with torch.distributed.launch\n    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = int(os.environ['LOCAL_RANK'])\n    # launched with submitit on a slurm cluster\n    elif 'SLURM_PROCID' in os.environ:\n        args.rank = int(os.environ['SLURM_PROCID'])\n        args.gpu = args.rank % torch.cuda.device_count()\n    # launched naively with `python main_dino.py`\n    # we manually add MASTER_ADDR and MASTER_PORT to env variables\n    elif torch.cuda.is_available():\n        print('Will run the code on one GPU.')\n        args.rank, args.gpu, args.world_size = 0, 0, 1\n        os.environ['MASTER_ADDR'] = '127.0.0.1'\n        os.environ['MASTER_PORT'] = '29500'\n    else:\n        print('Does not support training without GPU.')\n        
sys.exit(1)\n\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=args.dist_url,\n        world_size=args.world_size,\n        rank=args.rank,\n    )\n\n    torch.cuda.set_device(args.gpu)\n    print('| distributed init (rank {}): {}'.format(\n        args.rank, args.dist_url), flush=True)\n    dist.barrier()\n    setup_for_distributed(args.rank == 0)\n\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Computes the accuracy over the k top predictions for the specified values of k\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.reshape(1, -1).expand_as(pred))\n    return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]\n\n\ndef _no_grad_trunc_normal_(tensor, mean, std, a, b):\n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1. + math.erf(x / math.sqrt(2.))) / 2.\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n                      \"The distribution of values may be incorrect.\",\n                      stacklevel=2)\n\n    with torch.no_grad():\n        # Values are generated by using a truncated uniform distribution and\n        # then using the inverse CDF for the normal distribution.\n        # Get upper and lower cdf values\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n\n        # Uniformly fill tensor with values from [l, u], then translate to\n        # [2l-1, 2u-1].\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n        # Use inverse cdf transform for normal distribution to get truncated\n        # standard normal\n        tensor.erfinv_()\n\n        # Transform to proper mean, std\n        tensor.mul_(std * math.sqrt(2.))\n        tensor.add_(mean)\n\n        # Clamp to ensure it's in the proper range\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\n\ndef trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):\n    # type: (Tensor, float, float, float, float) -> Tensor\n    return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n\n\nclass LARS(torch.optim.Optimizer):\n    \"\"\"\n    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py\n    \"\"\"\n    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,\n                 weight_decay_filter=None, lars_adaptation_filter=None):\n        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,\n                        eta=eta, weight_decay_filter=weight_decay_filter,\n                        lars_adaptation_filter=lars_adaptation_filter)\n        super().__init__(params, defaults)\n\n    @torch.no_grad()\n    def step(self):\n        for g in self.param_groups:\n            for p in g['params']:\n                dp = p.grad\n\n                if dp is None:\n                    continue\n\n                if p.ndim != 1:\n                    dp = dp.add(p, alpha=g['weight_decay'])\n\n                if p.ndim != 1:\n                    param_norm = torch.norm(p)\n                    update_norm = torch.norm(dp)\n                    one = torch.ones_like(param_norm)\n                    q = torch.where(param_norm > 
0.,\n                                    torch.where(update_norm > 0,\n                                                (g['eta'] * param_norm / update_norm), one), one)\n                    dp = dp.mul(q)\n\n                param_state = self.state[p]\n                if 'mu' not in param_state:\n                    param_state['mu'] = torch.zeros_like(p)\n                mu = param_state['mu']\n                mu.mul_(g['momentum']).add_(dp)\n\n                p.add_(mu, alpha=-g['lr'])\n\ndef create_ds_config(args):\n    args.deepspeed_config = os.path.join(args.output_dir, \"deepspeed_config.json\")\n    with open(args.deepspeed_config, mode=\"w\") as writer:\n        ds_config = {\n            \"train_batch_size\": args.batch_size * get_world_size(),\n            \"train_micro_batch_size_per_gpu\": args.batch_size,\n            \"steps_per_print\": 1000,\n            \"optimizer\": {\n                \"type\": \"Adam\",\n                \"adam_w_mode\": True,\n                \"params\": {\n                    \"lr\": args.lr,\n                    \"weight_decay\": args.weight_decay,\n                    \"bias_correction\": True,\n                    \"betas\": [\n                        0.9,\n                        0.999\n                    ],\n                    \"eps\": 1e-8\n                }\n            },\n            \"fp16\": {\n                \"enabled\": True,\n                \"loss_scale\": 0,\n                \"initial_scale_power\": 7,\n                \"loss_scale_window\": 128\n            }\n        }\n\n        writer.write(json.dumps(ds_config, indent=2))\n\nclass MultiCropWrapper(nn.Module):\n    \"\"\"\n    Perform forward pass separately on each resolution input.\n    The inputs corresponding to a single resolution are clubbed and single\n    forward is run on the same resolution inputs. Hence we do several\n    forward passes = number of different resolutions used. 
We then\n    concatenate all the output features and run the head forward on these\n    concatenated features.\n    \"\"\"\n    def __init__(self, backbone, head=None):\n        super(MultiCropWrapper, self).__init__()\n        # disable layers dedicated to ImageNet labels classification\n        backbone.fc, backbone.head = nn.Identity(), nn.Identity()\n        self.backbone = backbone\n        if head is None:\n            self.head = nn.Identity()\n        else:\n            self.head = head\n\n    def forward(self, x, mask=None, return_backbone_feat=False, \n                **kwargs):\n        # convert to list\n        if not isinstance(x, list):\n            x = [x]\n            mask = [mask] if mask is not None else None\n        idx_crops = torch.cumsum(torch.unique_consecutive(\n            torch.tensor([inp.shape[-1] for inp in x]),\n            return_counts=True,\n        )[1], 0)\n        start_idx = 0\n        for end_idx in idx_crops:\n            inp_x = torch.cat(x[start_idx: end_idx])\n\n            if mask is not None:\n                inp_m = torch.cat(mask[start_idx: end_idx])\n                kwargs.update(dict(mask=inp_m))\n\n            _out = self.backbone(inp_x, **kwargs)\n            if start_idx == 0:\n                output = _out\n            else:\n                output = torch.cat((output, _out))\n            start_idx = end_idx\n        # Run the head forward on the concatenated features.\n        output_ = self.head(output)\n        if return_backbone_feat:\n            return output, output_\n        return output_\n\n\ndef get_params_groups(model):\n    regularized = []\n    not_regularized = []\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue\n        # we do not regularize biases nor Norm parameters\n        if name.endswith(\".bias\") or len(param.shape) == 1:\n            not_regularized.append(param)\n        else:\n            regularized.append(param)\n    return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]\n\n\ndef has_batchnorms(model):\n    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)\n    for name, module in model.named_modules():\n        if isinstance(module, bn_types):\n            return True\n    return False\n\n\ndef concat_all_gather(tensor):\n    \"\"\"\n    Performs all_gather operation on the provided tensors.\n    *** Warning ***: torch.distributed.all_gather has no gradient.\n    \"\"\"\n    tensors_gather = [torch.ones_like(tensor)\n        for _ in range(torch.distributed.get_world_size())]\n    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)\n\n    output = torch.cat(tensors_gather, dim=0)\n    return output\n\n\nclass PCA():\n    \"\"\"\n    Class to  compute and apply PCA.\n    \"\"\"\n    def __init__(self, dim=256, whit=0.5):\n        self.dim = dim\n        self.whit = whit\n        self.mean = None\n\n    def train_pca(self, cov):\n        \"\"\"\n        Takes a covariance matrix (np.ndarray) as input.\n        \"\"\"\n        d, v = np.linalg.eigh(cov)\n        eps = d.max() * 1e-5\n        n_0 = (d < eps).sum()\n        if n_0 > 0:\n            d[d < eps] = eps\n\n        # total energy\n        totenergy = d.sum()\n\n        # sort eigenvectors with eigenvalues order\n        idx = np.argsort(d)[::-1][:self.dim]\n        d = d[idx]\n        v = v[:, idx]\n\n        print(\"keeping %.2f %% of the energy\" % (d.sum() / totenergy * 100.0))\n\n        # for the whitening\n        
d = np.diag(1. / d**self.whit)\n\n        # principal components\n        self.dvt = np.dot(d, v.T)\n\n    def apply(self, x):\n        # input is from numpy\n        if isinstance(x, np.ndarray):\n            if self.mean is not None:\n                x -= self.mean\n            return np.dot(self.dvt, x.T).T\n\n        # input is from torch and is on GPU\n        if x.is_cuda:\n            if self.mean is not None:\n                x -= torch.cuda.FloatTensor(self.mean)\n            return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)\n\n        # input is from torch, on CPU\n        if self.mean is not None:\n            x -= torch.FloatTensor(self.mean)\n        return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)\n\n\ndef compute_ap(ranks, nres):\n    \"\"\"\n    Computes average precision for given ranked indexes.\n    Arguments\n    ---------\n    ranks : zero-based ranks of positive images\n    nres  : number of positive images\n    Returns\n    -------\n    ap    : average precision\n    \"\"\"\n\n    # number of images ranked by the system\n    nimgranks = len(ranks)\n\n    # accumulate trapezoids in PR-plot\n    ap = 0\n\n    recall_step = 1. / nres\n\n    for j in np.arange(nimgranks):\n        rank = ranks[j]\n\n        if rank == 0:\n            precision_0 = 1.\n        else:\n            precision_0 = float(j) / rank\n\n        precision_1 = float(j + 1) / (rank + 1)\n\n        ap += (precision_0 + precision_1) * recall_step / 2.\n\n    return ap\n\n\ndef compute_map(ranks, gnd, kappas=[]):\n    \"\"\"\n    Computes the mAP for a given set of returned results.\n         Usage:\n           map = compute_map (ranks, gnd)\n                 computes mean average precision (map) only\n           map, aps, pr, prs = compute_map (ranks, gnd, kappas)\n                 computes mean average precision (map), average precision (aps) for each query\n                 computes mean precision at kappas (pr), precision at kappas (prs) for each query\n         Notes:\n         1) ranks starts from 0, ranks.shape = db_size X #queries\n         2) The junk results (e.g., the query itself) should be declared in the gnd struct array\n         3) If there are no positive images for some query, that query is excluded from the evaluation\n    \"\"\"\n\n    map = 0.\n    nq = len(gnd) # number of queries\n    aps = np.zeros(nq)\n    pr = np.zeros(len(kappas))\n    prs = np.zeros((nq, len(kappas)))\n    nempty = 0\n\n    for i in np.arange(nq):\n        qgnd = np.array(gnd[i]['ok'])\n\n        # no positive images, skip from the average\n        if qgnd.shape[0] == 0:\n            aps[i] = float('nan')\n            prs[i, :] = float('nan')\n            nempty += 1\n            continue\n\n        try:\n            qgndj = np.array(gnd[i]['junk'])\n        except:\n            qgndj = np.empty(0)\n\n        # sorted positions of positive and junk images (0 based)\n        pos  = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]\n        junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]\n\n        k = 0\n        ij = 0\n        if len(junk):\n            # decrease positions of positives based on the number of\n            # junk images appearing before them\n            ip = 0\n            while (ip < len(pos)):\n                while (ij < len(junk) and pos[ip] > junk[ij]):\n                    k += 1\n                    ij += 1\n                pos[ip] = pos[ip] - k\n                ip += 1\n\n        # compute ap\n        ap = compute_ap(pos, len(qgnd))\n        map = map + ap\n        aps[i] = ap\n\n        # compute precision @ k\n        pos += 1 # get it to 1-based\n        for j in np.arange(len(kappas)):\n            kq = min(max(pos), kappas[j])\n            prs[i, j] = (pos <= kq).sum() / kq\n        pr = pr + prs[i, :]\n\n    map = map / (nq - nempty)\n    pr = pr / (nq - nempty)\n\n    return map, aps, pr, prs"
  },
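  {
    "path": "examples/compute_map_usage_sketch.py",
    "content": "# Hypothetical usage sketch (not a file from the original repository): it illustrates the\n# input format expected by the retrieval-evaluation helpers compute_ap / compute_map defined\n# above. The import path below is an assumption; point it at wherever these helpers live.\nimport numpy as np\n\nfrom utils import compute_map  # assumed module path\n\n# ranks[:, i] lists database indices for query i, sorted from most to least similar,\n# so ranks.shape = (db_size, num_queries).\nranks = np.array([\n    [3, 2],\n    [1, 0],\n    [4, 1],\n    [0, 4],\n    [2, 3],\n])\n\n# One ground-truth dict per query: 'ok' holds positive database indices,\n# 'junk' holds indices to ignore (e.g., the query image itself).\ngnd = [\n    {'ok': [0, 3], 'junk': []},\n    {'ok': [0], 'junk': [2]},\n]\n\nmAP, aps, pr, prs = compute_map(ranks, gnd, kappas=[1, 5])\nprint('mAP:', mAP)              # mean average precision over the two queries\nprint('per-query AP:', aps)     # average precision for each query\nprint('precision@[1, 5]:', pr)  # mean precision at the requested cutoffs\n"
  },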
  {
    "path": "downstream_tasks/semantic_segmentation/README.md",
    "content": "# ADE20k Semantic segmentation with CAE\n\n## Getting started \n\n1. Install the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library and some required packages.\n\n```bash\npip install mmcv-full==1.3.0 mmsegmentation==0.11.0\npip install scipy timm==0.3.2\n```\n\n2. Install [apex](https://github.com/NVIDIA/apex) for mixed-precision training\n\n```bash\ngit clone https://github.com/NVIDIA/apex\ncd apex\npip install -v --disable-pip-version-check --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./\n```\n\n3. Follow the guide in [mmseg](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/dataset_prepare.md) to prepare the ADE20k dataset.\n\n\n## Fine-tuning\n\nCommand format:\n```\ntools/dist_train.sh <CONFIG_PATH> <NUM_GPUS>  --work-dir <SAVE_PATH> --seed 0  --deterministic --options model.pretrained=<PRETRAIN_CHECKPOINT_PATH>\n```\n\nFor example, using a CAE-base backbone with UperNet:\n```bash\nbash tools/dist_train.sh \\\n    configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_4e-4.py 8 \\\n    --work-dir /path/to/save --seed 0  --deterministic \\\n    --options model.pretrained=<PRETRAIN_CHECKPOINT_PATH>\n```\n\nMore config files can be found at [`configs_local/cae/upernet`](configs_local/cae/upernet).\n\n\n## Evaluation\n\nCommand format:\n```\ntools/dist_test.sh  <CONFIG_PATH> <CHECKPOINT_PATH> <NUM_GPUS> --eval mIoU\n```\n\nFor example, evaluate a CAE-base backbone with UperNet:\n\n```bash\nbash tools/dist_test.sh configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_4e-4.py \\ \n    <CHECKPOINT_PATH> 8 --eval mIoU\n```\n\nPlease note that, the evaluation will be automatically conducted during training.\n\n## Results (pretrined models are trained on ImageNet-1K without label)\n\n| Backbone | #Pretrained Epoch | mIoU | Config                                   |\n| -------- | ----------------- | ---- | ---------------------------------------- |\n| ViT-B    | 300               | 48.1 | [3e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_3e-4.py) |\n| ViT-B    | 800               | 49.7 | [2e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py) |\n| ViT-B    | 1600              | 50.3 | [1e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_1e-4.py) |\n| ViT-L    | 1600              | 54.9 | [4e-5](./configs_local/cae/upernet/upernet_cae_large_24_512_slide_160k_ade20k_pt_decay095_4e-5_dp015.py) |\n\nWe find that, if the pretrained model is better, a smaller learning rate is more suitable. However, different learning rates will not lead to significantly different results. For example, 800-epoch pretrained ViT-B could obtain 49.6 mIoU (averaged from two runs) with lr=4e-4.\n\n## Acknowledgment \n\nThis code is built using the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library, [Timm](https://github.com/rwightman/pytorch-image-models) library, the [Swin](https://github.com/microsoft/Swin-Transformer) repository, [XCiT](https://github.com/facebookresearch/xcit) and the [SETR](https://github.com/fudan-zvg/SETR) repository.\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/backbone/beit.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\nimport math\nimport torch\nfrom functools import partial\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\n\nimport numpy as np\n\nfrom mmcv_custom import load_checkpoint\nfrom mmseg.utils import get_root_logger\nfrom mmseg.models.builder import BACKBONES\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = self.drop(x)\n        # commit this for the orignal BERT implement \n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n\n        if window_size:\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n            # get pair-wise relative position index for each token inside the window\n            coords_h = 
torch.arange(window_size[0])\n            coords_w = torch.arange(window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\n            relative_position_index[0, 0] = self.num_relative_distance - 1\n\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n\n            # trunc_normal_(self.relative_position_bias_table, std=.0)\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, rel_pos_bias=None):\n        B, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.relative_position_bias_table is not None:\n            relative_position_bias = \\\n                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                    self.window_size[0] * self.window_size[1] + 1,\n                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if rel_pos_bias is not None:\n            attn = attn + rel_pos_bias\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n   
         attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values is not None:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, rel_pos_bias=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x, **kwargs):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #     f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        Hp, Wp = x.shape[2], x.shape[3]\n\n        x = x.flatten(2).transpose(1, 2)\n        return x, (Hp, Wp)\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" CNN Feature Map Embedding\n    Extract feature map from CNN, flatten, project to embedding dim.\n    \"\"\"\n    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature\n                # map for all networks, the feature metadata has reliable channel and stride info, but using\n                # stride to calc feature dim requires info about padding of each stage that isn't captured.\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = to_2tuple(feature_size)\n            
feature_dim = self.backbone.feature_info.channels()[-1]\n        self.num_patches = feature_size[0] * feature_size[1]\n        self.proj = nn.Linear(feature_dim, embed_dim)\n\n    def forward(self, x):\n        x = self.backbone(x)[-1]\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass RelativePositionBias(nn.Module):\n\n    def __init__(self, window_size, num_heads):\n        super().__init__()\n        self.window_size = window_size\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        # cls to token & token 2 cls & cls to cls\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(window_size[0])\n        coords_w = torch.arange(window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n        relative_position_index = \\\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        relative_position_index[0, 0:] = self.num_relative_distance - 3\n        relative_position_index[0:, 0] = self.num_relative_distance - 2\n        relative_position_index[0, 0] = self.num_relative_distance - 1\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        # trunc_normal_(self.relative_position_bias_table, std=.02)\n\n    def forward(self):\n        relative_position_bias = \\\n            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                self.window_size[0] * self.window_size[1] + 1,\n                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\ndef get_sinusoid_encoding_table(n_position, d_hid, token=False):\n    ''' Sinusoid position encoding table '''\n\n    def get_position_angle_vec(position):\n        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n\n    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i\n    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1\n\n    if token:\n        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], dim=0)\n\n    return torch.FloatTensor(sinusoid_table).unsqueeze(0)\n\n@BACKBONES.register_module()\nclass BEiT(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., 
hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, \n                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_sincos_pos_embed=True, use_shared_rel_pos_bias=False,\n                 out_indices=[3, 5, 7, 11], out_with_norm=False):\n        super().__init__()\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n\n        if hybrid_backbone is not None:\n            self.patch_embed = HybridEmbed(\n                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)\n        else:\n            self.patch_embed = PatchEmbed(\n                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n        self.out_indices = out_indices\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.use_abs_pos_emb = use_abs_pos_emb\n\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        else:\n            # self.pos_embed = None\n            # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)\n            if use_sincos_pos_embed:\n                self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)\n            else:\n                self.pos_embed = None\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        if use_shared_rel_pos_bias:\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\n        else:\n            self.rel_pos_bias = None\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.use_rel_pos_bias = use_rel_pos_bias\n        self.use_checkpoint = use_checkpoint\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n\n        if self.pos_embed is not None:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        # trunc_normal_(self.mask_token, std=.02)\n        self.out_indices = out_indices\n\n        if patch_size == 16:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n                nn.SyncBatchNorm(embed_dim),\n                nn.GELU(),\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn3 = nn.Identity()\n\n            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)\n        elif patch_size == 8:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Identity()\n\n            self.fpn3 = nn.Sequential(\n                
nn.MaxPool2d(kernel_size=2, stride=2),\n            )\n\n            self.fpn4 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=4, stride=4),\n            )\n        \n        if not out_with_norm:\n            self.norm = nn.Identity()\n        else:\n            self.norm = norm_layer(embed_dim)\n\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n\n    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):\n        h, w = self.patch_embed.patch_shape\n        grid_w = torch.arange(w, dtype=torch.float32)\n        grid_h = torch.arange(h, dtype=torch.float32)\n        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)\n        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'\n        pos_dim = embed_dim // 4\n        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n        omega = 1. / (temperature ** omega)\n        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])\n        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])\n        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]\n\n        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)\n        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))\n        pos_embed.requires_grad = False\n        return pos_embed\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def forward_features(self, x):\n        B, C, H, W = x.shape\n        x, (Hp, Wp) = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), 
dim=1)\n        if self.pos_embed is not None:\n            '''\n            if self.use_abs_pos_emb:\n                x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()\n            else:\n                x = x[:,1:] + self.pos_embed.expand(batch_size, -1, -1).type_as(x[:,1:]).to(x.device).clone().detach()\n                x = torch.cat([x[:,:1],x],dim=1)\n            '''\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        features = []\n        for i, blk in enumerate(self.blocks):\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, rel_pos_bias)\n            else:\n                x = blk(x, rel_pos_bias)\n            if i in self.out_indices:\n                xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp)   \n                # xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)\n                features.append(xp.contiguous())\n\n        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]\n        for i in range(len(features)):\n            features[i] = ops[i](features[i])\n\n        return tuple(features)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        return x\n"
  },
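  {
    "path": "examples/sincos_pos_embed_sketch.py",
    "content": "# Hypothetical standalone sketch (not a file from the original repository): it mirrors the\n# 2D sin-cos position embedding built by BEiT.build_2d_sincos_position_embedding in the\n# backbone above, on a tiny 4x4 patch grid, to make the tensor shapes explicit.\nimport torch\n\n\ndef sincos_2d(h, w, embed_dim=64, temperature=10000.):\n    # A quarter of the channels each for sin/cos over the w and h coordinates.\n    assert embed_dim % 4 == 0, 'embed_dim must be divisible by 4'\n    grid_w = torch.arange(w, dtype=torch.float32)\n    grid_h = torch.arange(h, dtype=torch.float32)\n    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)\n    pos_dim = embed_dim // 4\n    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n    omega = 1. / (temperature ** omega)\n    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])\n    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])\n    pos = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]\n    # Prepend an all-zero embedding for the [CLS] token, as the backbone does.\n    cls_pe = torch.zeros([1, 1, embed_dim])\n    return torch.cat([cls_pe, pos], dim=1)\n\n\npe = sincos_2d(4, 4)\nprint(pe.shape)  # (1, 17, 64): 16 patch positions plus one [CLS] slot\n"
  },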
  {
    "path": "downstream_tasks/semantic_segmentation/backbone/beit_fapn.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\nimport math\nimport torch\nfrom functools import partial\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\n\nfrom mmcv_custom import load_checkpoint\nfrom mmseg.utils import get_root_logger\nfrom mmseg.models.builder import BACKBONES\nfrom mmcv.ops import DeformConv2d\nfrom mmcv.cnn import xavier_init\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = self.drop(x)\n        # commit this for the orignal BERT implement \n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n\n        if window_size:\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n            # get pair-wise relative position index for each token 
inside the window\n            coords_h = torch.arange(window_size[0])\n            coords_w = torch.arange(window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\n            relative_position_index[0, 0] = self.num_relative_distance - 1\n\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n\n            # trunc_normal_(self.relative_position_bias_table, std=.0)\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, rel_pos_bias=None):\n        B, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.relative_position_bias_table is not None:\n            relative_position_bias = \\\n                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                    self.window_size[0] * self.window_size[1] + 1,\n                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if rel_pos_bias is not None:\n            attn = attn + rel_pos_bias\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, 
qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values is not None:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, rel_pos_bias=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x, **kwargs):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #     f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        Hp, Wp = x.shape[2], x.shape[3]\n\n        x = x.flatten(2).transpose(1, 2)\n        return x, (Hp, Wp)\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" CNN Feature Map Embedding\n    Extract feature map from CNN, flatten, project to embedding dim.\n    \"\"\"\n    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature\n                # map for all networks, the feature metadata has reliable channel and stride info, but using\n                # stride to calc feature dim requires info about padding of each stage that isn't captured.\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            
feature_size = to_2tuple(feature_size)\n            feature_dim = self.backbone.feature_info.channels()[-1]\n        self.num_patches = feature_size[0] * feature_size[1]\n        self.proj = nn.Linear(feature_dim, embed_dim)\n\n    def forward(self, x):\n        x = self.backbone(x)[-1]\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass RelativePositionBias(nn.Module):\n\n    def __init__(self, window_size, num_heads):\n        super().__init__()\n        self.window_size = window_size\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        # cls to token & token 2 cls & cls to cls\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(window_size[0])\n        coords_w = torch.arange(window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n        relative_position_index = \\\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        relative_position_index[0, 0:] = self.num_relative_distance - 3\n        relative_position_index[0:, 0] = self.num_relative_distance - 2\n        relative_position_index[0, 0] = self.num_relative_distance - 1\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        # trunc_normal_(self.relative_position_bias_table, std=.02)\n\n    def forward(self):\n        relative_position_bias = \\\n            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                self.window_size[0] * self.window_size[1] + 1,\n                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\nclass FeatureSelectionModule(nn.Module):\n    def __init__(self, in_c, out_c, norm=\"GM\"):\n        super(FeatureSelectionModule, self).__init__()\n        self.conv_attn = nn.Sequential(\n                nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, bias=False).cuda(), # without norm and activation\n        )\n        self.sigmoid = nn.Sigmoid().cuda()\n        self.conv = nn.Sequential(\n                nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False).cuda(),\n                nn.BatchNorm2d(out_c).cuda(),\n                nn.ReLU(inplace=True).cuda()\n        )\n        xavier_init(self.conv_attn)\n        for m in self.conv:       \n            if isinstance(m, nn.Conv2d):\n                xavier_init(m, distribution='uniform')\n    def forward(self, x):\n        attn = self.sigmoid(self.conv_attn(F.avg_pool2d(x, x.size()[2:])))\n        feat = torch.mul(x, attn)\n        x = x + feat\n        feat = self.conv(x)\n        return feat\n\nclass FeatureAlign(nn.Module):\n    def 
__init__(self, in_c, out_c, norm=None):\n        super(FeatureAlign, self).__init__()\n        self.lateral_conv = FeatureSelectionModule(in_c, out_c, norm=\"\")\n        self.relu = nn.ReLU(inplace=True).cuda()\n        self.offset = nn.Conv2d(out_c*2, 144, kernel_size=1, stride=1, padding=0, bias=False).cuda() # 144=kernel_size[0]*kernel_size[1]*deform_groups*2\n        self.deform_conv2d = DeformConv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dilation=1, deform_groups=8).cuda()\n    def forward(self, feat_l, feat_s):\n        feat_l = feat_l.float()\n        feat_s = feat_s.float()\n        HW = feat_l.size()[2:]\n        if feat_l.size()[2:] != feat_s.size()[2:]:\n            feat_up = F.interpolate(feat_s, HW, mode='bilinear', align_corners=False)\n        else:\n            feat_up = feat_s\n        feat_arm = self.lateral_conv(feat_l)\n        offset = self.offset(torch.cat([feat_arm, feat_up], dim=1)).float()\n        feat_align = self.relu(self.deform_conv2d(feat_up, offset))\n        return feat_align + feat_arm\n\n@BACKBONES.register_module()\nclass BEiT_FaPN(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, \n                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,\n                 out_indices=[3, 5, 7, 11]):\n        super().__init__()\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.out_conv = []\n        self.FaPN = []\n        for i in range(len(out_indices)):\n            self.out_conv.append(nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=1, padding=1, bias=False).cuda())\n            self.FaPN.append(FeatureAlign(embed_dim, embed_dim))\n        if hybrid_backbone is not None:\n            self.patch_embed = HybridEmbed(\n                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)\n        else:\n            self.patch_embed = PatchEmbed(\n                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n        self.out_indices = out_indices\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        else:\n            self.pos_embed = None\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        if use_shared_rel_pos_bias:\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\n        else:\n            self.rel_pos_bias = None\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.use_rel_pos_bias = use_rel_pos_bias\n        self.use_checkpoint = use_checkpoint\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 
qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n\n        if self.pos_embed is not None:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        # trunc_normal_(self.mask_token, std=.02)\n        self.out_indices = out_indices\n\n        if patch_size == 16:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n                nn.SyncBatchNorm(embed_dim),\n                nn.GELU(),\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn3 = nn.Identity()\n\n            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)\n        elif patch_size == 8:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Identity()\n\n            self.fpn3 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=2, stride=2),\n            )\n\n            self.fpn4 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=4, stride=4),\n            )\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def forward_features(self, x):\n        B, C, H, W = x.shape\n        x, (Hp, Wp) = self.patch_embed(x)\n        batch_size, seq_len, _ = 
x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is not None:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        features = []\n        for i, blk in enumerate(self.blocks):\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, rel_pos_bias)\n            else:\n                x = blk(x, rel_pos_bias)\n            if i in self.out_indices:\n                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)\n                features.append(xp.contiguous())\n\n        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]\n        for i in range(len(features)):\n            features[i] = ops[i](features[i])\n        features_fapn = []\n        features_fapn.append(features[-1])\n        for i in range(len(ops)-1, 0, -1):\n            new_feature = self.FaPN[i](features[i-1], features[i])\n            new_feature = self.out_conv[i](new_feature)\n            features_fapn.append(new_feature)\n        features_fapn.reverse()\n        return tuple(features_fapn)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        return x\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/backbone/cae.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\nimport math\nimport torch\nfrom functools import partial\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\n\nimport numpy as np\n\nfrom mmcv_custom import load_checkpoint\nfrom mmseg.utils import get_root_logger\nfrom mmseg.models.builder import BACKBONES\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = self.drop(x)\n        # commit this for the orignal BERT implement \n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n\n        if window_size:\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n            # get pair-wise relative position index for each token inside the window\n            coords_h = 
torch.arange(window_size[0])\n            coords_w = torch.arange(window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\n            relative_position_index[0, 0] = self.num_relative_distance - 1\n\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n\n            # trunc_normal_(self.relative_position_bias_table, std=.0)\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, rel_pos_bias=None):\n        B, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.relative_position_bias_table is not None:\n            relative_position_bias = \\\n                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                    self.window_size[0] * self.window_size[1] + 1,\n                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if rel_pos_bias is not None:\n            attn = attn + rel_pos_bias\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n   
         attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values is not None:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, rel_pos_bias=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x, **kwargs):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #     f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        Hp, Wp = x.shape[2], x.shape[3]\n\n        x = x.flatten(2).transpose(1, 2)\n        return x, (Hp, Wp)\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" CNN Feature Map Embedding\n    Extract feature map from CNN, flatten, project to embedding dim.\n    \"\"\"\n    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature\n                # map for all networks, the feature metadata has reliable channel and stride info, but using\n                # stride to calc feature dim requires info about padding of each stage that isn't captured.\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = to_2tuple(feature_size)\n            
feature_dim = self.backbone.feature_info.channels()[-1]\n        self.num_patches = feature_size[0] * feature_size[1]\n        self.proj = nn.Linear(feature_dim, embed_dim)\n\n    def forward(self, x):\n        x = self.backbone(x)[-1]\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass RelativePositionBias(nn.Module):\n\n    def __init__(self, window_size, num_heads):\n        super().__init__()\n        self.window_size = window_size\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        # cls to token & token 2 cls & cls to cls\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(window_size[0])\n        coords_w = torch.arange(window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n        relative_position_index = \\\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        relative_position_index[0, 0:] = self.num_relative_distance - 3\n        relative_position_index[0:, 0] = self.num_relative_distance - 2\n        relative_position_index[0, 0] = self.num_relative_distance - 1\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        # trunc_normal_(self.relative_position_bias_table, std=.02)\n\n    def forward(self):\n        relative_position_bias = \\\n            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                self.window_size[0] * self.window_size[1] + 1,\n                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\ndef get_sinusoid_encoding_table(n_position, d_hid, token=False):\n    ''' Sinusoid position encoding table '''\n\n    def get_position_angle_vec(position):\n        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n\n    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i\n    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1\n\n    if token:\n        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], dim=0)\n\n    return torch.FloatTensor(sinusoid_table).unsqueeze(0)\n\n@BACKBONES.register_module()\nclass CAE(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., 
hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, \n                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_sincos_pos_embed=True, use_shared_rel_pos_bias=False,\n                 out_indices=[3, 5, 7, 11], out_with_norm=False):\n        super().__init__()\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n\n        if hybrid_backbone is not None:\n            self.patch_embed = HybridEmbed(\n                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)\n        else:\n            self.patch_embed = PatchEmbed(\n                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n        self.out_indices = out_indices\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.use_abs_pos_emb = use_abs_pos_emb\n\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        else:\n            # self.pos_embed = None\n            # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)\n            if use_sincos_pos_embed:\n                self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)\n            else:\n                self.pos_embed = None\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        if use_shared_rel_pos_bias:\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\n        else:\n            self.rel_pos_bias = None\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.use_rel_pos_bias = use_rel_pos_bias\n        self.use_checkpoint = use_checkpoint\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n\n        if self.pos_embed is not None:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        # trunc_normal_(self.mask_token, std=.02)\n        self.out_indices = out_indices\n\n        if patch_size == 16:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n                nn.SyncBatchNorm(embed_dim),\n                nn.GELU(),\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn3 = nn.Identity()\n\n            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)\n        elif patch_size == 8:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Identity()\n\n            self.fpn3 = nn.Sequential(\n                
nn.MaxPool2d(kernel_size=2, stride=2),\n            )\n\n            self.fpn4 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=4, stride=4),\n            )\n        \n        if not out_with_norm:\n            self.norm = nn.Identity()\n        else:\n            self.norm = norm_layer(embed_dim)\n\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n\n    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):\n        h, w = self.patch_embed.patch_shape\n        grid_w = torch.arange(w, dtype=torch.float32)\n        grid_h = torch.arange(h, dtype=torch.float32)\n        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)\n        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'\n        pos_dim = embed_dim // 4\n        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n        omega = 1. / (temperature ** omega)\n        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])\n        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])\n        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]\n\n        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)\n        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))\n        pos_embed.requires_grad = False\n        return pos_embed\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def forward_features(self, x):\n        B, C, H, W = x.shape\n        x, (Hp, Wp) = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), 
dim=1)\n        if self.pos_embed is not None:\n            '''\n            if self.use_abs_pos_emb:\n                x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()\n            else:\n                x = x[:,1:] + self.pos_embed.expand(batch_size, -1, -1).type_as(x[:,1:]).to(x.device).clone().detach()\n                x = torch.cat([x[:,:1],x],dim=1)\n            '''\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        features = []\n        for i, blk in enumerate(self.blocks):\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, rel_pos_bias)\n            else:\n                x = blk(x, rel_pos_bias)\n            if i in self.out_indices:\n                xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp)   \n                # xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)\n                features.append(xp.contiguous())\n\n        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]\n        for i in range(len(features)):\n            features[i] = ops[i](features[i])\n\n        return tuple(features)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        return x\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/backbone/fapn.py",
    "content": "class FeatureSelectionModule(nn.Module):\n    def __init__(self, in_c, out_c, norm=\"GM\"):\n        super(FeatureSelectionModule, self).__init__()\n        self.conv_attn = nn.Sequential(\n                nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, bias=False), # without norm and activation\n        )\n        self.sigmoid = nn.Sigmoid()\n        self.conv = nn.Sequential(\n                nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False),\n                nn.BatchNorm2d(out_c),\n                nn.ReLU(inplace=True)\n        )\n        xavier_init(self.conv_attn)\n        for m in self.conv.modeuls():       \n            if isintance(m, nn.Conv2d):\n                xavier_init(m, distribution='uniform')\n        def forward(self, x):\n            attn = self.sigmoid(self.conv_attn(F.avg_pool2d(x, x.size()[2:])))\n            feat = torch.mul(x, attn)\n            x = x + feat\n            feat = self.conv(x)\n            return feat\n\nclass FeatureAlign(nn.Module):\n    def __init__(self, in_c, out_c, norm=None):\n        super(FeatureAlign, self).__init__()\n        self.lateral_conv = FeatureSelectionModule(in_c, out_c, norm=\"\")\n        self.relu = nn.ReLU(inplace=True)\n        self.offset = nn.Conv2d(out_c * 2, out_c, kernel_size=1, stride=1, padding=0, bias=False)\n        self.deform_conv2d = DeformConv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dilation=1, deform_groups=8)\n    def forward(self, teat_l, feat_s):\n        HW = feat_l.size()[2:]\n        if feat_l.size()[2:] != feat_s.size()[2:]:\n            feat_up = F.interpolate(feat_s, HW, mode='bilinear', align_corners=False)\n        else:\n            feat_up = feat_s\n        feat_arm = self.lateral_conv(feat_l)\n        offset = self.offset(torch.cat([feat_arm, feat_up], dim=1))\n        feat_align = self.relu(self.deform_conv2d(feat_up, offset))\n        return feat_align + feat_arm\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/backbone/mae.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\nimport math\nimport torch\nfrom functools import partial\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\n\nimport numpy as np\n\nfrom mmcv_custom import load_checkpoint\nfrom mmseg.utils import get_root_logger\nfrom mmseg.models.builder import BACKBONES\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = self.drop(x)\n        # commit this for the orignal BERT implement \n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True)\n        # if qkv_bias:\n        #     self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n        #     self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        # else:\n        #     self.q_bias = None\n        #     self.v_bias = None\n\n        if window_size:\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n            # get pair-wise relative position index for each token inside the window\n            
coords_h = torch.arange(window_size[0])\n            coords_w = torch.arange(window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\n            relative_position_index[0, 0] = self.num_relative_distance - 1\n\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n\n            # trunc_normal_(self.relative_position_bias_table, std=.0)\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, rel_pos_bias=None):\n        B, N, C = x.shape\n        qkv_bias = None\n        # if self.q_bias is not None:\n        #     qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        # qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.relative_position_bias_table is not None:\n            relative_position_bias = \\\n                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                    self.window_size[0] * self.window_size[1] + 1,\n                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if rel_pos_bias is not None:\n            attn = attn + rel_pos_bias\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, 
qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values is not None:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, rel_pos_bias=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x, **kwargs):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #     f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        Hp, Wp = x.shape[2], x.shape[3]\n\n        x = x.flatten(2).transpose(1, 2)\n        return x, (Hp, Wp)\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" CNN Feature Map Embedding\n    Extract feature map from CNN, flatten, project to embedding dim.\n    \"\"\"\n    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature\n                # map for all networks, the feature metadata has reliable channel and stride info, but using\n                # stride to calc feature dim requires info about padding of each stage that isn't captured.\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = 
to_2tuple(feature_size)\n            feature_dim = self.backbone.feature_info.channels()[-1]\n        self.num_patches = feature_size[0] * feature_size[1]\n        self.proj = nn.Linear(feature_dim, embed_dim)\n\n    def forward(self, x):\n        x = self.backbone(x)[-1]\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass RelativePositionBias(nn.Module):\n\n    def __init__(self, window_size, num_heads):\n        super().__init__()\n        self.window_size = window_size\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        # cls to token & token 2 cls & cls to cls\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(window_size[0])\n        coords_w = torch.arange(window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n        relative_position_index = \\\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        relative_position_index[0, 0:] = self.num_relative_distance - 3\n        relative_position_index[0:, 0] = self.num_relative_distance - 2\n        relative_position_index[0, 0] = self.num_relative_distance - 1\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        # trunc_normal_(self.relative_position_bias_table, std=.02)\n\n    def forward(self):\n        relative_position_bias = \\\n            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                self.window_size[0] * self.window_size[1] + 1,\n                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\ndef get_sinusoid_encoding_table(n_position, d_hid, token=False):\n    ''' Sinusoid position encoding table '''\n\n    def get_position_angle_vec(position):\n        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n\n    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i\n    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1\n\n    if token:\n        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], dim=0)\n\n    return torch.FloatTensor(sinusoid_table).unsqueeze(0)\n\n@BACKBONES.register_module()\nclass MAE(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., 
attn_drop_rate=0.,\n                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, \n                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,\n                 out_indices=[3, 5, 7, 11]):\n        super().__init__()\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n\n        if hybrid_backbone is not None:\n            self.patch_embed = HybridEmbed(\n                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)\n        else:\n            self.patch_embed = PatchEmbed(\n                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n        self.out_indices = out_indices\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.use_abs_pos_emb = use_abs_pos_emb\n\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        else:\n            # self.pos_embed = None\n            # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)\n            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        if use_shared_rel_pos_bias:\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\n        else:\n            self.rel_pos_bias = None\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.use_rel_pos_bias = use_rel_pos_bias\n        self.use_checkpoint = use_checkpoint\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n\n        if self.pos_embed is not None:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        # trunc_normal_(self.mask_token, std=.02)\n        self.out_indices = out_indices\n\n        if patch_size == 16:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n                nn.SyncBatchNorm(embed_dim),\n                nn.GELU(),\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn3 = nn.Identity()\n\n            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)\n        elif patch_size == 8:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Identity()\n\n            self.fpn3 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=2, stride=2),\n            )\n\n            self.fpn4 = nn.Sequential(\n                
nn.MaxPool2d(kernel_size=4, stride=4),\n            )\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n\n    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):\n        h, w = self.patch_embed.patch_shape\n        grid_w = torch.arange(w, dtype=torch.float32)\n        grid_h = torch.arange(h, dtype=torch.float32)\n        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)\n        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'\n        pos_dim = embed_dim // 4\n        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n        omega = 1. / (temperature ** omega)\n        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])\n        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])\n        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]\n\n        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)\n        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))\n        pos_embed.requires_grad = False\n        return pos_embed\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def forward_features(self, x):\n        B, C, H, W = x.shape\n        x, (Hp, Wp) = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is not None:\n            '''\n            if self.use_abs_pos_emb:\n                x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()\n            else:\n                x = x[:,1:] + 
self.pos_embed.expand(batch_size, -1, -1).type_as(x[:,1:]).to(x.device).clone().detach()\n                x = torch.cat([x[:,:1],x],dim=1)\n            '''\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        features = []\n        for i, blk in enumerate(self.blocks):\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, rel_pos_bias)\n            else:\n                x = blk(x, rel_pos_bias)\n            if i in self.out_indices:\n                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)\n                features.append(xp.contiguous())\n\n        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]\n        for i in range(len(features)):\n            features[i] = ops[i](features[i])\n\n        return tuple(features)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        return x\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/ade20k.py",
    "content": "# dataset settings\ndataset_type = 'ADE20KDataset'\ndata_root = 'data/ade/ADEChallengeData2016'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ncrop_size = (512, 512)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', reduce_zero_label=True),\n    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 512),\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/training',\n        ann_dir='annotations/training',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/ade20k_640x640.py",
    "content": "# dataset settings\ndataset_type = 'ADE20KDataset'\ndata_root = 'data/ade/ADEChallengeData2016'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ncrop_size = (640, 640)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', reduce_zero_label=True),\n    dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2560, 640),\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/training',\n        ann_dir='annotations/training',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/chase_db1.py",
    "content": "# dataset settings\ndataset_type = 'ChaseDB1Dataset'\ndata_root = 'data/CHASE_DB1'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\nimg_scale = (960, 999)\ncrop_size = (128, 128)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations'),\n    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg'])\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=img_scale,\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img'])\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type='RepeatDataset',\n        times=40000,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            img_dir='images/training',\n            ann_dir='annotations/training',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/cityscapes.py",
    "content": "# dataset settings\ndataset_type = 'CityscapesDataset'\ndata_root = 'data/cityscapes/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ncrop_size = (512, 1024)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations'),\n    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 1024),\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=2,\n    train=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='leftImg8bit/train',\n        ann_dir='gtFine/train',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='leftImg8bit/val',\n        ann_dir='gtFine/val',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='leftImg8bit/val',\n        ann_dir='gtFine/val',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/cityscapes_769x769.py",
    "content": "_base_ = './cityscapes.py'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ncrop_size = (769, 769)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations'),\n    dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2049, 1025),\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    train=dict(pipeline=train_pipeline),\n    val=dict(pipeline=test_pipeline),\n    test=dict(pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/coco-stuff10k.py",
    "content": "# dataset settings\ndataset_type = 'COCOStuffDataset'\ndata_root = 'data/coco_stuff10k'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ncrop_size = (512, 512)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', reduce_zero_label=True),\n    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 512),\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type=dataset_type,\n        data_root=data_root,\n        reduce_zero_label=True,\n        img_dir='images/train2014',\n        ann_dir='annotations/train2014',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        reduce_zero_label=True,\n        img_dir='images/test2014',\n        ann_dir='annotations/test2014',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        reduce_zero_label=True,\n        img_dir='images/test2014',\n        ann_dir='annotations/test2014',\n        pipeline=test_pipeline))"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/drive.py",
    "content": "# dataset settings\ndataset_type = 'DRIVEDataset'\ndata_root = 'data/DRIVE'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\nimg_scale = (584, 565)\ncrop_size = (64, 64)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations'),\n    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg'])\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=img_scale,\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img'])\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type='RepeatDataset',\n        times=40000,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            img_dir='images/training',\n            ann_dir='annotations/training',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/hrf.py",
    "content": "# dataset settings\ndataset_type = 'HRFDataset'\ndata_root = 'data/HRF'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\nimg_scale = (2336, 3504)\ncrop_size = (256, 256)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations'),\n    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg'])\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=img_scale,\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img'])\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type='RepeatDataset',\n        times=40000,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            img_dir='images/training',\n            ann_dir='annotations/training',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_context.py",
    "content": "# dataset settings\ndataset_type = 'PascalContextDataset'\ndata_root = 'data/VOCdevkit/VOC2010/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n\nimg_scale = (520, 520)\ncrop_size = (480, 480)\n\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations'),\n    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=img_scale,\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='JPEGImages',\n        ann_dir='SegmentationClassContext',\n        split='ImageSets/SegmentationContext/train.txt',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='JPEGImages',\n        ann_dir='SegmentationClassContext',\n        split='ImageSets/SegmentationContext/val.txt',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='JPEGImages',\n        ann_dir='SegmentationClassContext',\n        split='ImageSets/SegmentationContext/val.txt',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_voc12.py",
    "content": "# dataset settings\ndataset_type = 'PascalVOCDataset'\ndata_root = 'data/VOCdevkit/VOC2012'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ncrop_size = (512, 512)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations'),\n    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 512),\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='JPEGImages',\n        ann_dir='SegmentationClass',\n        split='ImageSets/Segmentation/train.txt',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='JPEGImages',\n        ann_dir='SegmentationClass',\n        split='ImageSets/Segmentation/val.txt',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='JPEGImages',\n        ann_dir='SegmentationClass',\n        split='ImageSets/Segmentation/val.txt',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_voc12_aug.py",
    "content": "_base_ = './pascal_voc12.py'\n# dataset settings\ndata = dict(\n    train=dict(\n        ann_dir=['SegmentationClass', 'SegmentationClassAug'],\n        split=[\n            'ImageSets/Segmentation/train.txt',\n            'ImageSets/Segmentation/aug.txt'\n        ]))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/stare.py",
    "content": "# dataset settings\ndataset_type = 'STAREDataset'\ndata_root = 'data/STARE'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\nimg_scale = (605, 700)\ncrop_size = (128, 128)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations'),\n    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg'])\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=img_scale,\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img'])\n        ])\n]\n\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type='RepeatDataset',\n        times=40000,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            img_dir='images/training',\n            ann_dir='annotations/training',\n            pipeline=train_pipeline)),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/default_runtime.py",
    "content": "# yapf:disable\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook', by_epoch=False),\n        # dict(type='TensorboardLoggerHook')\n    ])\n# yapf:enable\ndist_params = dict(backend='nccl')\nlog_level = 'INFO'\nload_from = None\nresume_from = None\nworkflow = [('train', 1)]\ncudnn_benchmark = True\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/ann_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='ANNHead',\n        in_channels=[1024, 2048],\n        in_index=[2, 3],\n        channels=512,\n        project_channels=256,\n        query_scales=(1, ),\n        key_pool_scales=(1, 3, 6, 8),\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/apcnet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='APCHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        pool_scales=(1, 2, 3, 6),\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=dict(type='SyncBN', requires_grad=True),\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/ccnet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='CCHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        recurrence=2,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/cgnet.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    backbone=dict(\n        type='CGNet',\n        norm_cfg=norm_cfg,\n        in_channels=3,\n        num_channels=(32, 64, 128),\n        num_blocks=(3, 21),\n        dilations=(2, 4),\n        reductions=(8, 16)),\n    decode_head=dict(\n        type='FCNHead',\n        in_channels=256,\n        in_index=2,\n        channels=256,\n        num_convs=0,\n        concat_input=False,\n        dropout_ratio=0,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        loss_decode=dict(\n            type='CrossEntropyLoss',\n            use_sigmoid=False,\n            loss_weight=1.0,\n            class_weight=[\n                2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,\n                10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,\n                10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,\n                10.396974, 10.055647\n            ])),\n    # model training and testing settings\n    train_cfg=dict(sampler=None),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/danet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='DAHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        pam_channels=64,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='ASPPHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        dilations=(1, 12, 24, 36),\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3_unet_s5-d16.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained=None,\n    backbone=dict(\n        type='UNet',\n        in_channels=3,\n        base_channels=64,\n        num_stages=5,\n        strides=(1, 1, 1, 1, 1),\n        enc_num_convs=(2, 2, 2, 2, 2),\n        dec_num_convs=(2, 2, 2, 2),\n        downsamples=(True, True, True, True),\n        enc_dilations=(1, 1, 1, 1, 1),\n        dec_dilations=(1, 1, 1, 1),\n        with_cp=False,\n        conv_cfg=None,\n        norm_cfg=norm_cfg,\n        act_cfg=dict(type='ReLU'),\n        upsample_cfg=dict(type='InterpConv'),\n        norm_eval=False),\n    decode_head=dict(\n        type='ASPPHead',\n        in_channels=64,\n        in_index=4,\n        channels=16,\n        dilations=(1, 12, 24, 36),\n        dropout_ratio=0.1,\n        num_classes=2,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=128,\n        in_index=3,\n        channels=64,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=2,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='slide', crop_size=256, stride=170))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3plus_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='DepthwiseSeparableASPPHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        dilations=(1, 12, 24, 36),\n        c1_in_channels=256,\n        c1_channels=48,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/dmnet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='DMHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        filter_sizes=(1, 3, 5, 7),\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=dict(type='SyncBN', requires_grad=True),\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/dnl_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='DNLHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        dropout_ratio=0.1,\n        reduction=2,\n        use_scale=True,\n        mode='embedded_gaussian',\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/emanet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='EMAHead',\n        in_channels=2048,\n        in_index=3,\n        channels=256,\n        ema_channels=512,\n        num_bases=64,\n        num_stages=3,\n        momentum=0.1,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/encnet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='EncHead',\n        in_channels=[512, 1024, 2048],\n        in_index=(1, 2, 3),\n        channels=512,\n        num_codes=32,\n        use_se_loss=True,\n        add_lateral=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n        loss_se_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/fast_scnn.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)\nmodel = dict(\n    type='EncoderDecoder',\n    backbone=dict(\n        type='FastSCNN',\n        downsample_dw_channels=(32, 48),\n        global_in_channels=64,\n        global_block_channels=(64, 96, 128),\n        global_block_strides=(2, 2, 1),\n        global_out_channels=128,\n        higher_in_channels=64,\n        lower_in_channels=128,\n        fusion_out_channels=128,\n        out_indices=(0, 1, 2),\n        norm_cfg=norm_cfg,\n        align_corners=False),\n    decode_head=dict(\n        type='DepthwiseSeparableFCNHead',\n        in_channels=128,\n        channels=128,\n        concat_input=False,\n        num_classes=19,\n        in_index=-1,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),\n    auxiliary_head=[\n        dict(\n            type='FCNHead',\n            in_channels=128,\n            channels=32,\n            num_convs=1,\n            num_classes=19,\n            in_index=-2,\n            norm_cfg=norm_cfg,\n            concat_input=False,\n            align_corners=False,\n            loss_decode=dict(\n                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),\n        dict(\n            type='FCNHead',\n            in_channels=64,\n            channels=32,\n            num_convs=1,\n            num_classes=19,\n            in_index=-3,\n            norm_cfg=norm_cfg,\n            concat_input=False,\n            align_corners=False,\n            loss_decode=dict(\n                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),\n    ],\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_hr18.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://msra/hrnetv2_w18',\n    backbone=dict(\n        type='HRNet',\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        extra=dict(\n            stage1=dict(\n                num_modules=1,\n                num_branches=1,\n                block='BOTTLENECK',\n                num_blocks=(4, ),\n                num_channels=(64, )),\n            stage2=dict(\n                num_modules=1,\n                num_branches=2,\n                block='BASIC',\n                num_blocks=(4, 4),\n                num_channels=(18, 36)),\n            stage3=dict(\n                num_modules=4,\n                num_branches=3,\n                block='BASIC',\n                num_blocks=(4, 4, 4),\n                num_channels=(18, 36, 72)),\n            stage4=dict(\n                num_modules=3,\n                num_branches=4,\n                block='BASIC',\n                num_blocks=(4, 4, 4, 4),\n                num_channels=(18, 36, 72, 144)))),\n    decode_head=dict(\n        type='FCNHead',\n        in_channels=[18, 36, 72, 144],\n        in_index=(0, 1, 2, 3),\n        channels=sum([18, 36, 72, 144]),\n        input_transform='resize_concat',\n        kernel_size=1,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=-1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='FCNHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        num_convs=2,\n        concat_input=True,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_unet_s5-d16.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained=None,\n    backbone=dict(\n        type='UNet',\n        in_channels=3,\n        base_channels=64,\n        num_stages=5,\n        strides=(1, 1, 1, 1, 1),\n        enc_num_convs=(2, 2, 2, 2, 2),\n        dec_num_convs=(2, 2, 2, 2),\n        downsamples=(True, True, True, True),\n        enc_dilations=(1, 1, 1, 1, 1),\n        dec_dilations=(1, 1, 1, 1),\n        with_cp=False,\n        conv_cfg=None,\n        norm_cfg=norm_cfg,\n        act_cfg=dict(type='ReLU'),\n        upsample_cfg=dict(type='InterpConv'),\n        norm_eval=False),\n    decode_head=dict(\n        type='FCNHead',\n        in_channels=64,\n        in_index=4,\n        channels=64,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=2,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=128,\n        in_index=3,\n        channels=64,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=2,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='slide', crop_size=256, stride=170))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/fpn_r50.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 1, 1),\n        strides=(1, 2, 2, 2),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=4),\n    decode_head=dict(\n        type='FPNHead',\n        in_channels=[256, 256, 256, 256],\n        in_index=[0, 1, 2, 3],\n        feature_strides=[4, 8, 16, 32],\n        channels=128,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/gcnet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='GCHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        ratio=1 / 4.,\n        pooling_type='att',\n        fusion_types=('channel_add', ),\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/lraspp_m-v3-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    backbone=dict(\n        type='MobileNetV3',\n        arch='large',\n        out_indices=(1, 3, 16),\n        norm_cfg=norm_cfg),\n    decode_head=dict(\n        type='LRASPPHead',\n        in_channels=(16, 24, 960),\n        in_index=(0, 1, 2),\n        channels=128,\n        input_transform='multiple_select',\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        act_cfg=dict(type='ReLU'),\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/nonlocal_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='NLHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        dropout_ratio=0.1,\n        reduction=2,\n        use_scale=True,\n        mode='embedded_gaussian',\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/ocrnet_hr18.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='CascadeEncoderDecoder',\n    num_stages=2,\n    pretrained='open-mmlab://msra/hrnetv2_w18',\n    backbone=dict(\n        type='HRNet',\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        extra=dict(\n            stage1=dict(\n                num_modules=1,\n                num_branches=1,\n                block='BOTTLENECK',\n                num_blocks=(4, ),\n                num_channels=(64, )),\n            stage2=dict(\n                num_modules=1,\n                num_branches=2,\n                block='BASIC',\n                num_blocks=(4, 4),\n                num_channels=(18, 36)),\n            stage3=dict(\n                num_modules=4,\n                num_branches=3,\n                block='BASIC',\n                num_blocks=(4, 4, 4),\n                num_channels=(18, 36, 72)),\n            stage4=dict(\n                num_modules=3,\n                num_branches=4,\n                block='BASIC',\n                num_blocks=(4, 4, 4, 4),\n                num_channels=(18, 36, 72, 144)))),\n    decode_head=[\n        dict(\n            type='FCNHead',\n            in_channels=[18, 36, 72, 144],\n            channels=sum([18, 36, 72, 144]),\n            in_index=(0, 1, 2, 3),\n            input_transform='resize_concat',\n            kernel_size=1,\n            num_convs=1,\n            concat_input=False,\n            dropout_ratio=-1,\n            num_classes=19,\n            norm_cfg=norm_cfg,\n            align_corners=False,\n            loss_decode=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n        dict(\n            type='OCRHead',\n            in_channels=[18, 36, 72, 144],\n            in_index=(0, 1, 2, 3),\n            input_transform='resize_concat',\n            channels=512,\n            ocr_channels=256,\n            dropout_ratio=-1,\n            num_classes=19,\n            norm_cfg=norm_cfg,\n            align_corners=False,\n            loss_decode=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    ],\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/ocrnet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='CascadeEncoderDecoder',\n    num_stages=2,\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=[\n        dict(\n            type='FCNHead',\n            in_channels=1024,\n            in_index=2,\n            channels=256,\n            num_convs=1,\n            concat_input=False,\n            dropout_ratio=0.1,\n            num_classes=19,\n            norm_cfg=norm_cfg,\n            align_corners=False,\n            loss_decode=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n        dict(\n            type='OCRHead',\n            in_channels=2048,\n            in_index=3,\n            channels=512,\n            ocr_channels=256,\n            dropout_ratio=0.1,\n            num_classes=19,\n            norm_cfg=norm_cfg,\n            align_corners=False,\n            loss_decode=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))\n    ],\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/pointrend_r50.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='CascadeEncoderDecoder',\n    num_stages=2,\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 1, 1),\n        strides=(1, 2, 2, 2),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=4),\n    decode_head=[\n        dict(\n            type='FPNHead',\n            in_channels=[256, 256, 256, 256],\n            in_index=[0, 1, 2, 3],\n            feature_strides=[4, 8, 16, 32],\n            channels=128,\n            dropout_ratio=-1,\n            num_classes=19,\n            norm_cfg=norm_cfg,\n            align_corners=False,\n            loss_decode=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n        dict(\n            type='PointHead',\n            in_channels=[256],\n            in_index=[0],\n            channels=256,\n            num_fcs=3,\n            coarse_pred_each_layer=True,\n            dropout_ratio=-1,\n            num_classes=19,\n            align_corners=False,\n            loss_decode=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))\n    ],\n    # model training and testing settings\n    train_cfg=dict(\n        num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75),\n    test_cfg=dict(\n        mode='whole',\n        subdivision_steps=2,\n        subdivision_num_points=8196,\n        scale_factor=2))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/psanet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='PSAHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        mask_size=(97, 97),\n        psa_type='bi-direction',\n        compact=False,\n        shrink_factor=2,\n        normalization_factor=1.0,\n        psa_softmax=True,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/pspnet_r50-d8.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 2, 4),\n        strides=(1, 2, 1, 1),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='PSPHead',\n        in_channels=2048,\n        in_index=3,\n        channels=512,\n        pool_scales=(1, 2, 3, 6),\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/pspnet_unet_s5-d16.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained=None,\n    backbone=dict(\n        type='UNet',\n        in_channels=3,\n        base_channels=64,\n        num_stages=5,\n        strides=(1, 1, 1, 1, 1),\n        enc_num_convs=(2, 2, 2, 2, 2),\n        dec_num_convs=(2, 2, 2, 2),\n        downsamples=(True, True, True, True),\n        enc_dilations=(1, 1, 1, 1, 1),\n        dec_dilations=(1, 1, 1, 1),\n        with_cp=False,\n        conv_cfg=None,\n        norm_cfg=norm_cfg,\n        act_cfg=dict(type='ReLU'),\n        upsample_cfg=dict(type='InterpConv'),\n        norm_eval=False),\n    decode_head=dict(\n        type='PSPHead',\n        in_channels=64,\n        in_index=4,\n        channels=16,\n        pool_scales=(1, 2, 3, 6),\n        dropout_ratio=0.1,\n        num_classes=2,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=128,\n        in_index=3,\n        channels=64,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=2,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='slide', crop_size=256, stride=170))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/upernet_cae.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained=None,\n    backbone=dict(\n        type='XCiT',\n        patch_size=16,\n        embed_dim=384,\n        depth=12,\n        num_heads=8,\n        mlp_ratio=4,\n        qkv_bias=True,\n        use_abs_pos_emb=True,\n        use_rel_pos_bias=False,\n    ),\n    decode_head=dict(\n        type='UPerHead',\n        in_channels=[384, 384, 384, 384],\n        in_index=[0, 1, 2, 3],\n        pool_scales=(1, 2, 3, 6),\n        channels=512,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=384,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/models/upernet_r50.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 1, 1),\n        strides=(1, 2, 2, 2),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='UPerHead',\n        in_channels=[256, 512, 1024, 2048],\n        in_index=[0, 1, 2, 3],\n        pool_scales=(1, 2, 3, 6),\n        channels=512,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_160k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=160000)\ncheckpoint_config = dict(by_epoch=False, interval=16000)\nevaluation = dict(interval=16000, metric='mIoU')\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_20k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=20000)\ncheckpoint_config = dict(by_epoch=False, interval=2000)\nevaluation = dict(interval=2000, metric='mIoU')\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_320k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=320000)\ncheckpoint_config = dict(by_epoch=False, interval=32000)\nevaluation = dict(interval=32000, metric='mIoU')\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_40k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=40000)\ncheckpoint_config = dict(by_epoch=False, interval=4000)\nevaluation = dict(interval=4000, metric='mIoU')\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_80k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=80000)\ncheckpoint_config = dict(by_epoch=False, interval=8000)\nevaluation = dict(interval=8000, metric='mIoU')\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/beit/upernet_beit_base_12_512_slide_160k_ade20k_pt_4e-4.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\n_base_ = [\n    '../../_base_/models/upernet_beit.py', '../../_base_/datasets/ade20k.py',\n    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'\n]\ncrop_size = (512, 512)\n\nmodel = dict(\n    backbone=dict(\n        type='BEiT',\n        img_size=512,\n        patch_size=16,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        qkv_bias=True,\n        use_abs_pos_emb=False,\n        use_sincos_pos_embed=False,\n        use_rel_pos_bias=True,\n        init_values=0.1,\n        drop_path_rate=0.1,\n        out_indices=[3, 5, 7, 11]\n    ),\n    decode_head=dict(\n        in_channels=[768, 768, 768, 768],\n        num_classes=150,\n        channels=768,\n    ),\n    auxiliary_head=dict(\n        in_channels=768,\n        num_classes=150\n    ), \n    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\n)\n\n# AdamW optimizer, no weight decay for position embedding & layer norm in backbone\n# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,\n#                  paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n#                                                  'relative_position_bias_table': dict(decay_mult=0.),\n#                                                  'norm': dict(decay_mult=0.)}))\n\noptimizer = dict(_delete_=True, type='AdamW', lr=4e-4, betas=(0.9, 0.999), weight_decay=0.05,\n                 constructor='LayerDecayOptimizerConstructor', \n                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))\n\nlr_config = dict(_delete_=True, policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0, min_lr=0.0, by_epoch=False)\n\n# By default, models are trained on 8 GPUs with 2 images per GPU\ndata=dict(samples_per_gpu=2)\n#img_norm_cfg = dict(\n#    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n#crop_size = (512, 512)\n## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\nfind_unused_parameters = True\n\n#test_pipeline = [\n#    dict(type='LoadImageFromFile'),\n#    dict(\n#        type='MultiScaleFlipAug',\n#        img_scale=(2048, 512),\n#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n#        flip=True,\n#        transforms=[\n#            dict(type='SETR_Resize', keep_ratio=True,\n#                 crop_size=crop_size, setr_multi_scale=True),\n#            dict(type='RandomFlip'),\n#            dict(type='Normalize', **img_norm_cfg),\n#            dict(type='ImageToTensor', keys=['img']),\n#            dict(type='Collect', keys=['img']),\n#        ])\n#]\n#data = dict(\n#    val=dict(pipeline=test_pipeline),\n#    test=dict(pipeline=test_pipeline), \n#    samples_per_gpu=2, \n#)\n\nrunner = 
dict(type='IterBasedRunnerAmp')\ncheckpoint_config = dict(by_epoch=False, interval=8000)\n\n# do not use mmdet version fp16\nfp16 = None\noptimizer_config = dict(\n    type=\"DistOptimizerHook\",\n    update_interval=1,\n    grad_clip=None,\n    coalesce=True,\n    bucket_size_mb=-1,\n    use_fp16=True,\n)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_1e-4.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\n_base_ = [\n    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',\n    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'\n]\ncrop_size = (512, 512)\n\nmodel = dict(\n    backbone=dict(\n        type='CAE',\n        img_size=512,\n        patch_size=16,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        qkv_bias=True,\n        use_abs_pos_emb=False,\n        use_rel_pos_bias=True,\n        init_values=0.1,\n        drop_path_rate=0.1,\n        out_indices=[3, 5, 7, 11]\n    ),\n    decode_head=dict(\n        in_channels=[768, 768, 768, 768],\n        num_classes=150,\n        channels=768,\n    ),\n    auxiliary_head=dict(\n        in_channels=768,\n        num_classes=150\n    ), \n    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\n)\n\noptimizer = dict(_delete_=True, type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.05,\n                 constructor='LayerDecayOptimizerConstructor', \n                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))\n\nlr_config = dict(_delete_=True, policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0, min_lr=0.0, by_epoch=False)\n\n# By default, models are trained on 8 GPUs with 2 images per GPU\ndata=dict(samples_per_gpu=2)\nfind_unused_parameters = True\n\n#test_pipeline = [\n#    dict(type='LoadImageFromFile'),\n#    dict(\n#        type='MultiScaleFlipAug',\n#        img_scale=(2048, 512),\n#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n#        flip=True,\n#        transforms=[\n#            dict(type='SETR_Resize', keep_ratio=True,\n#                 crop_size=crop_size, setr_multi_scale=True),\n#            dict(type='RandomFlip'),\n#            dict(type='Normalize', **img_norm_cfg),\n#            dict(type='ImageToTensor', keys=['img']),\n#            dict(type='Collect', keys=['img']),\n#        ])\n#]\n#data = dict(\n#    val=dict(pipeline=test_pipeline),\n#    test=dict(pipeline=test_pipeline), \n#    samples_per_gpu=2, \n#)\n\nrunner = dict(type='IterBasedRunnerAmp')\ncheckpoint_config = dict(by_epoch=False, interval=8000)\n\n# do not use mmdet version fp16\nfp16 = None\noptimizer_config = dict(\n    type=\"DistOptimizerHook\",\n    update_interval=1,\n    grad_clip=None,\n    coalesce=True,\n    bucket_size_mb=-1,\n    use_fp16=True,\n)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\n_base_ = [\n    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',\n    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'\n]\ncrop_size = (512, 512)\n\nmodel = dict(\n    backbone=dict(\n        type='CAE',\n        img_size=512,\n        patch_size=16,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        qkv_bias=True,\n        use_abs_pos_emb=False,\n        use_rel_pos_bias=True,\n        init_values=0.1,\n        drop_path_rate=0.1,\n        out_indices=[3, 5, 7, 11]\n    ),\n    decode_head=dict(\n        in_channels=[768, 768, 768, 768],\n        num_classes=150,\n        channels=768,\n    ),\n    auxiliary_head=dict(\n        in_channels=768,\n        num_classes=150\n    ), \n    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\n)\n\noptimizer = dict(_delete_=True, type='AdamW', lr=2e-4, betas=(0.9, 0.999), weight_decay=0.05,\n                 constructor='LayerDecayOptimizerConstructor', \n                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))\n\nlr_config = dict(_delete_=True, policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0, min_lr=0.0, by_epoch=False)\n\n# By default, models are trained on 8 GPUs with 2 images per GPU\ndata=dict(samples_per_gpu=2)\nfind_unused_parameters = True\n\n#test_pipeline = [\n#    dict(type='LoadImageFromFile'),\n#    dict(\n#        type='MultiScaleFlipAug',\n#        img_scale=(2048, 512),\n#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n#        flip=True,\n#        transforms=[\n#            dict(type='SETR_Resize', keep_ratio=True,\n#                 crop_size=crop_size, setr_multi_scale=True),\n#            dict(type='RandomFlip'),\n#            dict(type='Normalize', **img_norm_cfg),\n#            dict(type='ImageToTensor', keys=['img']),\n#            dict(type='Collect', keys=['img']),\n#        ])\n#]\n#data = dict(\n#    val=dict(pipeline=test_pipeline),\n#    test=dict(pipeline=test_pipeline), \n#    samples_per_gpu=2, \n#)\n\nrunner = dict(type='IterBasedRunnerAmp')\ncheckpoint_config = dict(by_epoch=False, interval=8000)\n\n# do not use mmdet version fp16\nfp16 = None\noptimizer_config = dict(\n    type=\"DistOptimizerHook\",\n    update_interval=1,\n    grad_clip=None,\n    coalesce=True,\n    bucket_size_mb=-1,\n    use_fp16=True,\n)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_3e-4.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\n_base_ = [\n    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',\n    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'\n]\ncrop_size = (512, 512)\n\nmodel = dict(\n    backbone=dict(\n        type='CAE',\n        img_size=512,\n        patch_size=16,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        qkv_bias=True,\n        use_abs_pos_emb=False,\n        use_rel_pos_bias=True,\n        init_values=0.1,\n        drop_path_rate=0.1,\n        out_indices=[3, 5, 7, 11]\n    ),\n    decode_head=dict(\n        in_channels=[768, 768, 768, 768],\n        num_classes=150,\n        channels=768,\n    ),\n    auxiliary_head=dict(\n        in_channels=768,\n        num_classes=150\n    ), \n    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\n)\n\noptimizer = dict(_delete_=True, type='AdamW', lr=3e-4, betas=(0.9, 0.999), weight_decay=0.05,\n                 constructor='LayerDecayOptimizerConstructor', \n                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))\n\nlr_config = dict(_delete_=True, policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0, min_lr=0.0, by_epoch=False)\n\n# By default, models are trained on 8 GPUs with 2 images per GPU\ndata=dict(samples_per_gpu=2)\nfind_unused_parameters = True\n\n#test_pipeline = [\n#    dict(type='LoadImageFromFile'),\n#    dict(\n#        type='MultiScaleFlipAug',\n#        img_scale=(2048, 512),\n#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n#        flip=True,\n#        transforms=[\n#            dict(type='SETR_Resize', keep_ratio=True,\n#                 crop_size=crop_size, setr_multi_scale=True),\n#            dict(type='RandomFlip'),\n#            dict(type='Normalize', **img_norm_cfg),\n#            dict(type='ImageToTensor', keys=['img']),\n#            dict(type='Collect', keys=['img']),\n#        ])\n#]\n#data = dict(\n#    val=dict(pipeline=test_pipeline),\n#    test=dict(pipeline=test_pipeline), \n#    samples_per_gpu=2, \n#)\n\nrunner = dict(type='IterBasedRunnerAmp')\ncheckpoint_config = dict(by_epoch=False, interval=8000)\n\n# do not use mmdet version fp16\nfp16 = None\noptimizer_config = dict(\n    type=\"DistOptimizerHook\",\n    update_interval=1,\n    grad_clip=None,\n    coalesce=True,\n    bucket_size_mb=-1,\n    use_fp16=True,\n)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_large_24_512_slide_160k_ade20k_pt_decay095_4e-5_dp015.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\n_base_ = [\n    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',\n    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'\n]\ncrop_size = (512, 512)\n\nmodel = dict(\n    backbone=dict(\n        type='CAE',\n        img_size=512,\n        patch_size=16,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        qkv_bias=True,\n        use_abs_pos_emb=False,\n        use_rel_pos_bias=True,\n        init_values=1e-5,\n        drop_path_rate=0.15,\n        out_indices=[7, 11, 15, 23],\n    ),\n    decode_head=dict(\n        in_channels=[1024, 1024, 1024, 1024],\n        num_classes=150,\n        channels=1024,\n    ),\n    auxiliary_head=dict(\n        in_channels=1024,\n        num_classes=150\n    ), \n    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\n)\n\n# AdamW optimizer, no weight decay for position embedding & layer norm in backbone\n# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,\n#                  paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n#                                                  'relative_position_bias_table': dict(decay_mult=0.),\n#                                                  'norm': dict(decay_mult=0.)}))\n\noptimizer = dict(_delete_=True, type='AdamW', lr=4e-5, betas=(0.9, 0.999), weight_decay=0.05,\n                 constructor='LayerDecayOptimizerConstructor', \n                 paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95))\n\nlr_config = dict(_delete_=True, policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0, min_lr=0.0, by_epoch=False)\n\n# By default, models are trained on 8 GPUs with 2 images per GPU\ndata=dict(samples_per_gpu=2)\n#img_norm_cfg = dict(\n#    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n#crop_size = (512, 512)\n## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\nfind_unused_parameters = True\n\n#test_pipeline = [\n#    dict(type='LoadImageFromFile'),\n#    dict(\n#        type='MultiScaleFlipAug',\n#        img_scale=(2048, 512),\n#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n#        flip=True,\n#        transforms=[\n#            dict(type='SETR_Resize', keep_ratio=True,\n#                 crop_size=crop_size, setr_multi_scale=True),\n#            dict(type='RandomFlip'),\n#            dict(type='Normalize', **img_norm_cfg),\n#            dict(type='ImageToTensor', keys=['img']),\n#            dict(type='Collect', keys=['img']),\n#        ])\n#]\n#data = dict(\n#    val=dict(pipeline=test_pipeline),\n#    test=dict(pipeline=test_pipeline), \n#    samples_per_gpu=2, \n#)\n\nrunner = dict(type='IterBasedRunnerAmp')\ncheckpoint_config 
= dict(by_epoch=False, interval=32000)\n\n# do not use mmdet version fp16\nfp16 = None\noptimizer_config = dict(\n    type=\"DistOptimizerHook\",\n    update_interval=1,\n    grad_clip=None,\n    coalesce=True,\n    bucket_size_mb=-1,\n    use_fp16=True,\n)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/configs_local/mae/upernet_mae_large_12_512_slide_160k_ade20k_pt_4e-4.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, mmseg, setr, xcit and swin code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/fudan-zvg/SETR\n# https://github.com/facebookresearch/xcit/\n# https://github.com/microsoft/Swin-Transformer\n# --------------------------------------------------------'\n_base_ = [\n    '../../_base_/models/upernet_beit.py', '../../_base_/datasets/ade20k.py',\n    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'\n]\ncrop_size = (512, 512)\n\nmodel = dict(\n    backbone=dict(\n        type='MAE',\n        img_size=512,\n        patch_size=16,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        qkv_bias=True,\n        use_abs_pos_emb=False,\n        use_rel_pos_bias=True,\n        init_values=1,\n        drop_path_rate=0.2,\n        out_indices=[7, 11, 15, 23],\n    ),\n    decode_head=dict(\n        in_channels=[1024, 1024, 1024, 1024],\n        num_classes=150,\n        channels=1024,\n    ),\n    auxiliary_head=dict(\n        in_channels=1024,\n        num_classes=150\n    ), \n    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\n)\n\n# AdamW optimizer, no weight decay for position embedding & layer norm in backbone\n# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,\n#                  paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n#                                                  'relative_position_bias_table': dict(decay_mult=0.),\n#                                                  'norm': dict(decay_mult=0.)}))\n\noptimizer = dict(_delete_=True, type='AdamW', lr=4e-4, betas=(0.9, 0.999), weight_decay=0.05,\n                 constructor='LayerDecayOptimizerConstructor', \n                 paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.65))\n\nlr_config = dict(_delete_=True, policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0, min_lr=0.0, by_epoch=False)\n\n# By default, models are trained on 8 GPUs with 2 images per GPU\ndata=dict(samples_per_gpu=2)\n#img_norm_cfg = dict(\n#    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n#crop_size = (512, 512)\n## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))\nfind_unused_parameters = True\n\n#test_pipeline = [\n#    dict(type='LoadImageFromFile'),\n#    dict(\n#        type='MultiScaleFlipAug',\n#        img_scale=(2048, 512),\n#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n#        flip=True,\n#        transforms=[\n#            dict(type='SETR_Resize', keep_ratio=True,\n#                 crop_size=crop_size, setr_multi_scale=True),\n#            dict(type='RandomFlip'),\n#            dict(type='Normalize', **img_norm_cfg),\n#            dict(type='ImageToTensor', keys=['img']),\n#            dict(type='Collect', keys=['img']),\n#        ])\n#]\n#data = dict(\n#    val=dict(pipeline=test_pipeline),\n#    test=dict(pipeline=test_pipeline), \n#    samples_per_gpu=2, \n#)\n\nrunner = dict(type='IterBasedRunnerAmp')\ncheckpoint_config = 
dict(by_epoch=False, interval=32000)\n\n# do not use mmdet version fp16\nfp16 = None\noptimizer_config = dict(\n    type=\"DistOptimizerHook\",\n    update_interval=1,\n    grad_clip=None,\n    coalesce=True,\n    bucket_size_mb=-1,\n    use_fp16=True,\n)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom .checkpoint import load_checkpoint\nfrom .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor\nfrom .resize_transform import SETR_Resize\nfrom .apex_runner.optimizer import DistOptimizerHook\nfrom .train_api import train_segmentor\n\n__all__ = ['load_checkpoint', 'LayerDecayOptimizerConstructor', 'SETR_Resize', 'DistOptimizerHook', 'train_segmentor']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/__init__.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\nfrom .checkpoint import save_checkpoint\nfrom .apex_iter_based_runner import IterBasedRunnerAmp\n\n\n__all__ = [\n    'save_checkpoint', 'IterBasedRunnerAmp', \n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/apex_iter_based_runner.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\nimport os.path as osp\nimport platform\nimport shutil\n\nimport torch\nfrom torch.optim import Optimizer\n\nimport mmcv\nfrom mmcv.runner import RUNNERS, IterBasedRunner\nfrom .checkpoint import save_checkpoint\n\ntry:\n    import apex\nexcept:\n    print('apex is not installed')\n\n\n@RUNNERS.register_module()\nclass IterBasedRunnerAmp(IterBasedRunner):\n    \"\"\"Iteration-based Runner with AMP support.\n\n    This runner train models iteration by iteration.\n    \"\"\"\n\n    def save_checkpoint(self,\n                        out_dir,\n                        filename_tmpl='iter_{}.pth',\n                        meta=None,\n                        save_optimizer=True,\n                        create_symlink=False):\n        \"\"\"Save checkpoint to file.\n\n        Args:\n            out_dir (str): Directory to save checkpoint files.\n            filename_tmpl (str, optional): Checkpoint file template.\n                Defaults to 'iter_{}.pth'.\n            meta (dict, optional): Metadata to be saved in checkpoint.\n                Defaults to None.\n            save_optimizer (bool, optional): Whether save optimizer.\n                Defaults to True.\n            create_symlink (bool, optional): Whether create symlink to the\n                latest checkpoint file. Defaults to True.\n        \"\"\"\n        if meta is None:\n            meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)\n        elif isinstance(meta, dict):\n            meta.update(iter=self.iter + 1, epoch=self.epoch + 1)\n        else:\n            raise TypeError(\n                f'meta should be a dict or None, but got {type(meta)}')\n        if self.meta is not None:\n            meta.update(self.meta)\n\n        filename = filename_tmpl.format(self.iter + 1)\n        filepath = osp.join(out_dir, filename)\n        optimizer = self.optimizer if save_optimizer else None\n        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)\n        # in some environments, `os.symlink` is not supported, you may need to\n        # set `create_symlink` to False\n        # if create_symlink:\n        #     dst_file = osp.join(out_dir, 'latest.pth')\n        #     if platform.system() != 'Windows':\n        #         mmcv.symlink(filename, dst_file)\n        #     else:\n        #         shutil.copy(filepath, dst_file)\n\n    def resume(self,\n               checkpoint,\n               resume_optimizer=True,\n               map_location='default'):\n        if map_location == 'default':\n            if torch.cuda.is_available():\n                device_id = torch.cuda.current_device()\n                checkpoint = self.load_checkpoint(\n                    checkpoint,\n                    map_location=lambda storage, loc: storage.cuda(device_id))\n            else:\n                checkpoint = self.load_checkpoint(checkpoint)\n        else:\n            checkpoint = self.load_checkpoint(\n                checkpoint, map_location=map_location)\n\n        self._epoch = checkpoint['meta']['epoch']\n        self._iter = checkpoint['meta']['iter']\n        self._inner_iter = checkpoint['meta']['iter']\n        if 'optimizer' in checkpoint and resume_optimizer:\n            if isinstance(self.optimizer, Optimizer):\n                self.optimizer.load_state_dict(checkpoint['optimizer'])\n            elif isinstance(self.optimizer, dict):\n                for k in self.optimizer.keys():\n                    self.optimizer[k].load_state_dict(\n    
                    checkpoint['optimizer'][k])\n            else:\n                raise TypeError(\n                    'Optimizer should be dict or torch.optim.Optimizer '\n                    f'but got {type(self.optimizer)}')\n\n        if 'amp' in checkpoint:\n            apex.amp.load_state_dict(checkpoint['amp'])\n            self.logger.info('load amp state dict')\n\n        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/checkpoint.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\nimport os.path as osp\nimport time\nfrom tempfile import TemporaryDirectory\n\nimport torch\nfrom torch.optim import Optimizer\n\nimport mmcv\nfrom mmcv.parallel import is_module_wrapper\nfrom mmcv.runner.checkpoint import weights_to_cpu, get_state_dict\n\ntry:\n    import apex\nexcept:\n    print('apex is not installed')\n\n\ndef save_checkpoint(model, filename, optimizer=None, meta=None):\n    \"\"\"Save checkpoint to file.\n\n    The checkpoint will have 4 fields: ``meta``, ``state_dict`` and\n    ``optimizer``, ``amp``. By default ``meta`` will contain version\n    and time info.\n\n    Args:\n        model (Module): Module whose params are to be saved.\n        filename (str): Checkpoint filename.\n        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.\n        meta (dict, optional): Metadata to be saved in checkpoint.\n    \"\"\"\n    if meta is None:\n        meta = {}\n    elif not isinstance(meta, dict):\n        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')\n    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())\n\n    if is_module_wrapper(model):\n        model = model.module\n\n    if hasattr(model, 'CLASSES') and model.CLASSES is not None:\n        # save class name to the meta\n        meta.update(CLASSES=model.CLASSES)\n\n    checkpoint = {\n        'meta': meta,\n        'state_dict': weights_to_cpu(get_state_dict(model))\n    }\n    # save optimizer state dict in the checkpoint\n    if isinstance(optimizer, Optimizer):\n        checkpoint['optimizer'] = optimizer.state_dict()\n    elif isinstance(optimizer, dict):\n        checkpoint['optimizer'] = {}\n        for name, optim in optimizer.items():\n            checkpoint['optimizer'][name] = optim.state_dict()\n\n    # save amp state dict in the checkpoint\n    checkpoint['amp'] = apex.amp.state_dict()\n\n    if filename.startswith('pavi://'):\n        try:\n            from pavi import modelcloud\n            from pavi.exception import NodeNotFoundError\n        except ImportError:\n            raise ImportError(\n                'Please install pavi to load checkpoint from modelcloud.')\n        model_path = filename[7:]\n        root = modelcloud.Folder()\n        model_dir, model_name = osp.split(model_path)\n        try:\n            model = modelcloud.get(model_dir)\n        except NodeNotFoundError:\n            model = root.create_training_model(model_dir)\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint_file = osp.join(tmp_dir, model_name)\n            with open(checkpoint_file, 'wb') as f:\n                torch.save(checkpoint, f)\n                f.flush()\n            model.create_file(checkpoint_file, name=model_name)\n    else:\n        mmcv.mkdir_or_exist(osp.dirname(filename))\n        # immediately flush buffer\n        with open(filename, 'wb') as f:\n            torch.save(checkpoint, f)\n            f.flush()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/optimizer.py",
    "content": "from mmcv.runner import OptimizerHook, HOOKS\ntry:\n    import apex\nexcept:\n    print('apex is not installed')\n\n\n@HOOKS.register_module()\nclass DistOptimizerHook(OptimizerHook):\n    \"\"\"Optimizer hook for distributed training.\"\"\"\n\n    def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):\n        self.grad_clip = grad_clip\n        self.coalesce = coalesce\n        self.bucket_size_mb = bucket_size_mb\n        self.update_interval = update_interval\n        self.use_fp16 = use_fp16\n\n    def before_run(self, runner):\n        runner.optimizer.zero_grad()\n\n    def after_train_iter(self, runner):\n        runner.outputs['loss'] /= self.update_interval\n        if self.use_fp16:\n            with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss:\n                scaled_loss.backward()\n        else:\n            runner.outputs['loss'].backward()\n        if self.every_n_iters(runner, self.update_interval):\n            if self.grad_clip is not None:\n                self.clip_grads(runner.model.parameters())\n            runner.optimizer.step()\n            runner.optimizer.zero_grad()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/checkpoint.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\nimport io\nimport os\nimport os.path as osp\nimport pkgutil\nimport time\nimport warnings\nfrom collections import OrderedDict\nfrom importlib import import_module\nfrom tempfile import TemporaryDirectory\n\nimport torch\nimport torchvision\nfrom torch.optim import Optimizer\nfrom torch.utils import model_zoo\nfrom torch.nn import functional as F\n\nimport mmcv\nfrom mmcv.fileio import FileClient\nfrom mmcv.fileio import load as load_file\nfrom mmcv.parallel import is_module_wrapper\nfrom mmcv.utils import mkdir_or_exist\nfrom mmcv.runner import get_dist_info\n\nfrom scipy import interpolate\nimport numpy as np\nimport math\n\nENV_MMCV_HOME = 'MMCV_HOME'\nENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'\nDEFAULT_CACHE_DIR = '~/.cache'\n\n\ndef _get_mmcv_home():\n    mmcv_home = os.path.expanduser(\n        os.getenv(\n            ENV_MMCV_HOME,\n            os.path.join(\n                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))\n\n    mkdir_or_exist(mmcv_home)\n    return mmcv_home\n\n\ndef load_state_dict(module, state_dict, strict=False, logger=None):\n    \"\"\"Load state_dict to a module.\n\n    This method is modified from :meth:`torch.nn.Module.load_state_dict`.\n    Default value for ``strict`` is set to ``False`` and the message for\n    param mismatch will be shown even if strict is False.\n\n    Args:\n        module (Module): Module that receives the state_dict.\n        state_dict (OrderedDict): Weights.\n        strict (bool): whether to strictly enforce that the keys\n            in :attr:`state_dict` match the keys returned by this module's\n            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.\n        logger (:obj:`logging.Logger`, optional): Logger to log the error\n            message. 
If not specified, print function will be used.\n    \"\"\"\n    unexpected_keys = []\n    all_missing_keys = []\n    err_msg = []\n\n    metadata = getattr(state_dict, '_metadata', None)\n    state_dict = state_dict.copy()\n    if metadata is not None:\n        state_dict._metadata = metadata\n\n    # use _load_from_state_dict to enable checkpoint version control\n    def load(module, prefix=''):\n        # recursively check parallel module in case that the model has a\n        # complicated structure, e.g., nn.Module(nn.Module(DDP))\n        if is_module_wrapper(module):\n            module = module.module\n        local_metadata = {} if metadata is None else metadata.get(\n            prefix[:-1], {})\n        module._load_from_state_dict(state_dict, prefix, local_metadata, True,\n                                     all_missing_keys, unexpected_keys,\n                                     err_msg)\n        for name, child in module._modules.items():\n            if child is not None:\n                load(child, prefix + name + '.')\n\n    load(module)\n    load = None  # break load->load reference cycle\n\n    # ignore \"num_batches_tracked\" of BN layers\n    missing_keys = [\n        key for key in all_missing_keys if 'num_batches_tracked' not in key\n    ]\n\n    if unexpected_keys:\n        err_msg.append('unexpected key in source '\n                       f'state_dict: {\", \".join(unexpected_keys)}\\n')\n    if missing_keys:\n        err_msg.append(\n            f'missing keys in source state_dict: {\", \".join(missing_keys)}\\n')\n\n    rank, _ = get_dist_info()\n    if len(err_msg) > 0 and rank == 0:\n        err_msg.insert(\n            0, 'The model and loaded state dict do not match exactly\\n')\n        err_msg = '\\n'.join(err_msg)\n        if strict:\n            raise RuntimeError(err_msg)\n        elif logger is not None:\n            logger.warning(err_msg)\n        else:\n            print(err_msg)\n\n\ndef load_url_dist(url, model_dir=None, map_location=\"cpu\"):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    if rank == 0:\n        checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)\n    return checkpoint\n\n\ndef load_pavimodel_dist(model_path, map_location=None):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    try:\n        from pavi import modelcloud\n    except ImportError:\n        raise ImportError(\n            'Please install pavi to load checkpoint from modelcloud.')\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    if rank == 0:\n        model = modelcloud.get(model_path)\n        with TemporaryDirectory() as tmp_dir:\n            downloaded_file = osp.join(tmp_dir, model.name)\n            model.download(downloaded_file)\n            checkpoint = torch.load(downloaded_file, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            model = modelcloud.get(model_path)\n            with TemporaryDirectory() as tmp_dir:\n                downloaded_file = osp.join(tmp_dir, model.name)\n                model.download(downloaded_file)\n   
             checkpoint = torch.load(\n                    downloaded_file, map_location=map_location)\n    return checkpoint\n\n\ndef load_fileclient_dist(filename, backend, map_location):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    allowed_backends = ['ceph']\n    if backend not in allowed_backends:\n        raise ValueError(f'Load from Backend {backend} is not supported.')\n    if rank == 0:\n        fileclient = FileClient(backend=backend)\n        buffer = io.BytesIO(fileclient.get(filename))\n        checkpoint = torch.load(buffer, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            fileclient = FileClient(backend=backend)\n            buffer = io.BytesIO(fileclient.get(filename))\n            checkpoint = torch.load(buffer, map_location=map_location)\n    return checkpoint\n\n\ndef get_torchvision_models():\n    model_urls = dict()\n    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):\n        if ispkg:\n            continue\n        _zoo = import_module(f'torchvision.models.{name}')\n        if hasattr(_zoo, 'model_urls'):\n            _urls = getattr(_zoo, 'model_urls')\n            model_urls.update(_urls)\n    return model_urls\n\n\ndef get_external_models():\n    mmcv_home = _get_mmcv_home()\n    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')\n    default_urls = load_file(default_json_path)\n    assert isinstance(default_urls, dict)\n    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')\n    if osp.exists(external_json_path):\n        external_urls = load_file(external_json_path)\n        assert isinstance(external_urls, dict)\n        default_urls.update(external_urls)\n\n    return default_urls\n\n\ndef get_mmcls_models():\n    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')\n    mmcls_urls = load_file(mmcls_json_path)\n\n    return mmcls_urls\n\n\ndef get_deprecated_model_names():\n    deprecate_json_path = osp.join(mmcv.__path__[0],\n                                   'model_zoo/deprecated.json')\n    deprecate_urls = load_file(deprecate_json_path)\n    assert isinstance(deprecate_urls, dict)\n\n    return deprecate_urls\n\n\ndef _process_mmcls_checkpoint(checkpoint):\n    state_dict = checkpoint['state_dict']\n    new_state_dict = OrderedDict()\n    for k, v in state_dict.items():\n        if k.startswith('backbone.'):\n            new_state_dict[k[9:]] = v\n    new_checkpoint = dict(state_dict=new_state_dict)\n\n    return new_checkpoint\n\n\ndef _load_checkpoint(filename, map_location=None):\n    \"\"\"Load checkpoint from somewhere (modelzoo, file, url).\n\n    Args:\n        filename (str): Accept local filepath, URL, ``torchvision://xxx``,\n            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for\n            details.\n        map_location (str | None): Same as :func:`torch.load`. Default: None.\n\n    Returns:\n        dict | OrderedDict: The loaded checkpoint. 
It can be either an\n            OrderedDict storing model weights or a dict containing other\n            information, which depends on the checkpoint.\n    \"\"\"\n    if filename.startswith('modelzoo://'):\n        warnings.warn('The URL scheme of \"modelzoo://\" is deprecated, please '\n                      'use \"torchvision://\" instead')\n        model_urls = get_torchvision_models()\n        model_name = filename[11:]\n        checkpoint = load_url_dist(model_urls[model_name])\n    elif filename.startswith('torchvision://'):\n        model_urls = get_torchvision_models()\n        model_name = filename[14:]\n        checkpoint = load_url_dist(model_urls[model_name])\n    elif filename.startswith('open-mmlab://'):\n        model_urls = get_external_models()\n        model_name = filename[13:]\n        deprecated_urls = get_deprecated_model_names()\n        if model_name in deprecated_urls:\n            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '\n                          f'of open-mmlab://{deprecated_urls[model_name]}')\n            model_name = deprecated_urls[model_name]\n        model_url = model_urls[model_name]\n        # check if is url\n        if model_url.startswith(('http://', 'https://')):\n            checkpoint = load_url_dist(model_url)\n        else:\n            filename = osp.join(_get_mmcv_home(), model_url)\n            if not osp.isfile(filename):\n                raise IOError(f'{filename} is not a checkpoint file')\n            checkpoint = torch.load(filename, map_location=map_location)\n    elif filename.startswith('mmcls://'):\n        model_urls = get_mmcls_models()\n        model_name = filename[8:]\n        checkpoint = load_url_dist(model_urls[model_name])\n        checkpoint = _process_mmcls_checkpoint(checkpoint)\n    elif filename.startswith(('http://', 'https://')):\n        checkpoint = load_url_dist(filename)\n    elif filename.startswith('pavi://'):\n        model_path = filename[7:]\n        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)\n    elif filename.startswith('s3://'):\n        checkpoint = load_fileclient_dist(\n            filename, backend='ceph', map_location=map_location)\n    else:\n        if not osp.isfile(filename):\n            raise IOError(f'{filename} is not a checkpoint file')\n        checkpoint = torch.load(filename, map_location=map_location)\n    return checkpoint\n\n\ndef cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,\n                     start_warmup_value=0, warmup_steps=-1):\n    warmup_schedule = np.array([])\n    warmup_iters = warmup_epochs * niter_per_ep\n    if warmup_steps > 0:\n        warmup_iters = warmup_steps\n    print(\"Set warmup steps = %d\" % warmup_iters)\n    if warmup_epochs > 0:\n        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)\n\n    iters = np.arange(epochs * niter_per_ep - warmup_iters)\n    schedule = np.array(\n        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])\n\n    schedule = np.concatenate((warmup_schedule, schedule))\n\n    assert len(schedule) == epochs * niter_per_ep\n    return schedule\n\n\ndef load_checkpoint(model,\n                    filename,\n                    map_location='cpu',\n                    strict=False,\n                    logger=None):\n    \"\"\"Load checkpoint from a file or URI.\n\n    Args:\n        model (Module): Module to load checkpoint.\n        filename (str): Accept 
local filepath, URL, ``torchvision://xxx``,\n            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for\n            details.\n        map_location (str): Same as :func:`torch.load`.\n        strict (bool): Whether to allow different params for the model and\n            checkpoint.\n        logger (:mod:`logging.Logger` or None): The logger for error message.\n\n    Returns:\n        dict or OrderedDict: The loaded checkpoint.\n    \"\"\"\n    checkpoint = _load_checkpoint(filename, map_location)\n    # OrderedDict is a subclass of dict\n    if not isinstance(checkpoint, dict):\n        raise RuntimeError(\n            f'No state_dict found in checkpoint file {filename}')\n    # get state_dict from checkpoint\n    if 'state_dict' in checkpoint:\n        state_dict = checkpoint['state_dict']\n    elif 'model' in checkpoint:\n        state_dict = checkpoint['model']\n    elif 'module' in checkpoint:\n        state_dict = checkpoint['module']\n    else:\n        state_dict = checkpoint\n    # strip prefix of state_dict\n    if list(state_dict.keys())[0].startswith('module.'):\n        state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n    # for MoBY, load model of online branch\n    if sorted(list(state_dict.keys()))[0].startswith('encoder'):\n        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}\n\n    all_keys = list(state_dict.keys())\n    print(\"origin keys:\", len(all_keys), all_keys)\n    if all_keys[-1].startswith('encoder_to_decoder') or all_keys[-1].startswith('decoder'):\n        # NOTE: remove all decoder keys\n        all_keys = [key for key in all_keys if key.startswith('encoder.')]\n        print(\"all keys:\", all_keys)\n        for key in all_keys:\n            new_key = key.replace('encoder.','')\n            # print(\"new_key:\", new_key)\n            state_dict[new_key] = state_dict[key]\n            state_dict.pop(key)\n            \n        for key in list(state_dict.keys()):\n            if key.startswith('decoder.'):\n                # print(\"key:\", key)\n                state_dict.pop(key)\n\n        # NOTE: replace norm with fc_norm\n        for key in list(state_dict.keys()):\n            # print(\"new key:\", key)\n            if key.startswith('norm.'):\n                new_key = key.replace('norm.','fc_norm.')\n                state_dict[new_key] = state_dict[key]\n                state_dict.pop(key)\n    \n    print(\"new keys:\", len(state_dict), state_dict.keys())\n    \n    # reshape absolute position embedding for Swin\n    if state_dict.get('absolute_pos_embed') is not None:\n        absolute_pos_embed = state_dict['absolute_pos_embed']\n        N1, L, C1 = absolute_pos_embed.size()\n        N2, C2, H, W = model.absolute_pos_embed.size()\n        if N1 != N2 or C1 != C2 or L != H*W:\n            logger.warning(\"Error in loading absolute_pos_embed, pass\")\n        else:\n            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)\n\n    rank, _ = get_dist_info()\n    if \"rel_pos_bias.relative_position_bias_table\" in state_dict:\n        if rank == 0:\n            print(\"Expand the shared relative position embedding to each layers. 
\")\n            num_layers = model.get_num_layers()\n            rel_pos_bias = state_dict[\"rel_pos_bias.relative_position_bias_table\"]\n            for i in range(num_layers):\n                state_dict[\"blocks.%d.attn.relative_position_bias_table\" % i] = rel_pos_bias.clone()\n\n        state_dict.pop(\"rel_pos_bias.relative_position_bias_table\")\n\n    all_keys = list(state_dict.keys())\n\n    # for moco\n    for key in all_keys:\n        if 'base_encoder.' in key:\n            new_key = key.replace('base_encoder.', '')\n            state_dict[new_key] = state_dict[key]\n            state_dict.pop(key)\n        if 'momentum_encoder' in key:\n            state_dict.pop(key)\n\n    # for ibot\n    for key in all_keys:\n        if 'module.backbone.' in key:\n            new_key = key.replace('module.backbone.', '')\n            state_dict[new_key] = state_dict[key]\n            state_dict.pop(key)\n        elif 'backbone.' in key:\n            new_key = key.replace('backbone.', '')\n            state_dict[new_key] = state_dict[key]\n            state_dict.pop(key)\n\n    for key in all_keys:\n        if \"relative_position_index\" in key:\n            state_dict.pop(key)\n\n        if \"relative_position_bias_table\" in key:\n            rel_pos_bias = state_dict[key]\n            src_num_pos, num_attn_heads = rel_pos_bias.size()\n            dst_num_pos, _ = model.state_dict()[key].size()\n            dst_patch_shape = model.patch_embed.patch_shape\n            if dst_patch_shape[0] != dst_patch_shape[1]:\n                raise NotImplementedError()\n            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)\n            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)\n            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)\n            if src_size != dst_size:\n                if rank == 0:\n                    print(\"Position interpolate for %s from %dx%d to %dx%d\" % (\n                        key, src_size, src_size, dst_size, dst_size))\n                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\n                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\n\n                def geometric_progression(a, r, n):\n                    return a * (1.0 - r ** n) / (1.0 - r)\n\n                left, right = 1.01, 1.5\n                while right - left > 1e-6:\n                    q = (left + right) / 2.0\n                    gp = geometric_progression(1, q, src_size // 2)\n                    if gp > dst_size // 2:\n                        right = q\n                    else:\n                        left = q\n\n                # if q > 1.13492:\n                #     q = 1.13492\n\n                dis = []\n                cur = 1\n                for i in range(src_size // 2):\n                    dis.append(cur)\n                    cur += q ** (i + 1)\n\n                r_ids = [-_ for _ in reversed(dis)]\n\n                x = r_ids + [0] + dis\n                y = r_ids + [0] + dis\n\n                t = dst_size // 2.0\n                dx = np.arange(-t, t + 0.1, 1.0)\n                dy = np.arange(-t, t + 0.1, 1.0)\n                if rank == 0:\n                    print(\"x = {}\".format(x))\n                    print(\"dx = {}\".format(dx))\n\n                all_rel_pos_bias = []\n\n                for i in range(num_attn_heads):\n                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()\n                    f = interpolate.interp2d(x, y, z, kind='cubic')\n          
          all_rel_pos_bias.append(\n                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))\n\n                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)\n                state_dict[key] = new_rel_pos_bias\n\n    if 'pos_embed' in state_dict:  #and model.use_abs_pos_emb:\n        pos_embed_checkpoint = state_dict['pos_embed']\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        new_size = int(num_patches ** 0.5)\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size:\n            if rank == 0:\n                print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            state_dict['pos_embed'] = new_pos_embed\n\n    # interpolate position bias table if needed\n    relative_position_bias_table_keys = [k for k in state_dict.keys() if \"relative_position_bias_table\" in k]\n    for table_key in relative_position_bias_table_keys:\n        table_pretrained = state_dict[table_key]\n        table_current = model.state_dict()[table_key]\n        L1, nH1 = table_pretrained.size()\n        L2, nH2 = table_current.size()\n        if nH1 != nH2:\n            logger.warning(f\"Error in loading {table_key}, pass\")\n        else:\n            if L1 != L2:\n                S1 = int(L1 ** 0.5)\n                S2 = int(L2 ** 0.5)\n                table_pretrained_resized = F.interpolate(\n                     table_pretrained.permute(1, 0).view(1, nH1, S1, S1),\n                     size=(S2, S2), mode='bicubic')\n                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)\n\n    # load state_dict\n    load_state_dict(model, state_dict, strict, logger)\n    return checkpoint\n\n\ndef weights_to_cpu(state_dict):\n    \"\"\"Copy a model state_dict to cpu.\n\n    Args:\n        state_dict (OrderedDict): Model weights on GPU.\n\n    Returns:\n        OrderedDict: Model weights on GPU.\n    \"\"\"\n    state_dict_cpu = OrderedDict()\n    for key, val in state_dict.items():\n        state_dict_cpu[key] = val.cpu()\n    return state_dict_cpu\n\n\ndef _save_to_state_dict(module, destination, prefix, keep_vars):\n    \"\"\"Saves module state to `destination` dictionary.\n\n    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.\n\n    Args:\n        module (nn.Module): The module to generate state_dict.\n        destination (dict): A dict where state will be stored.\n        prefix (str): The prefix for parameters and 
buffers used in this\n            module.\n    \"\"\"\n    for name, param in module._parameters.items():\n        if param is not None:\n            destination[prefix + name] = param if keep_vars else param.detach()\n    for name, buf in module._buffers.items():\n        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d\n        if buf is not None:\n            destination[prefix + name] = buf if keep_vars else buf.detach()\n\n\ndef get_state_dict(module, destination=None, prefix='', keep_vars=False):\n    \"\"\"Returns a dictionary containing a whole state of the module.\n\n    Both parameters and persistent buffers (e.g. running averages) are\n    included. Keys are corresponding parameter and buffer names.\n\n    This method is modified from :meth:`torch.nn.Module.state_dict` to\n    recursively check parallel module in case that the model has a complicated\n    structure, e.g., nn.Module(nn.Module(DDP)).\n\n    Args:\n        module (nn.Module): The module to generate state_dict.\n        destination (OrderedDict): Returned dict for the state of the\n            module.\n        prefix (str): Prefix of the key.\n        keep_vars (bool): Whether to keep the variable property of the\n            parameters. Default: False.\n\n    Returns:\n        dict: A dictionary containing a whole state of the module.\n    \"\"\"\n    # recursively check parallel module in case that the model has a\n    # complicated structure, e.g., nn.Module(nn.Module(DDP))\n    if is_module_wrapper(module):\n        module = module.module\n\n    # below is the same as torch.nn.Module.state_dict()\n    if destination is None:\n        destination = OrderedDict()\n        destination._metadata = OrderedDict()\n    destination._metadata[prefix[:-1]] = local_metadata = dict(\n        version=module._version)\n    _save_to_state_dict(module, destination, prefix, keep_vars)\n    for name, child in module._modules.items():\n        if child is not None:\n            get_state_dict(\n                child, destination, prefix + name + '.', keep_vars=keep_vars)\n    for hook in module._state_dict_hooks.values():\n        hook_result = hook(module, destination, prefix, local_metadata)\n        if hook_result is not None:\n            destination = hook_result\n    return destination\n\n\ndef save_checkpoint(model, filename, optimizer=None, meta=None):\n    \"\"\"Save checkpoint to file.\n\n    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and\n    ``optimizer``. 
By default ``meta`` will contain version and time info.\n\n    Args:\n        model (Module): Module whose params are to be saved.\n        filename (str): Checkpoint filename.\n        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.\n        meta (dict, optional): Metadata to be saved in checkpoint.\n    \"\"\"\n    if meta is None:\n        meta = {}\n    elif not isinstance(meta, dict):\n        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')\n    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())\n\n    if is_module_wrapper(model):\n        model = model.module\n\n    if hasattr(model, 'CLASSES') and model.CLASSES is not None:\n        # save class name to the meta\n        meta.update(CLASSES=model.CLASSES)\n\n    checkpoint = {\n        'meta': meta,\n        'state_dict': weights_to_cpu(get_state_dict(model))\n    }\n    # save optimizer state dict in the checkpoint\n    if isinstance(optimizer, Optimizer):\n        checkpoint['optimizer'] = optimizer.state_dict()\n    elif isinstance(optimizer, dict):\n        checkpoint['optimizer'] = {}\n        for name, optim in optimizer.items():\n            checkpoint['optimizer'][name] = optim.state_dict()\n\n    if filename.startswith('pavi://'):\n        try:\n            from pavi import modelcloud\n            from pavi.exception import NodeNotFoundError\n        except ImportError:\n            raise ImportError(\n                'Please install pavi to load checkpoint from modelcloud.')\n        model_path = filename[7:]\n        root = modelcloud.Folder()\n        model_dir, model_name = osp.split(model_path)\n        try:\n            model = modelcloud.get(model_dir)\n        except NodeNotFoundError:\n            model = root.create_training_model(model_dir)\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint_file = osp.join(tmp_dir, model_name)\n            with open(checkpoint_file, 'wb') as f:\n                torch.save(checkpoint, f)\n                f.flush()\n            model.create_file(checkpoint_file, name=model_name)\n    else:\n        mmcv.mkdir_or_exist(osp.dirname(filename))\n        # immediately flush buffer\n        with open(filename, 'wb') as f:\n            torch.save(checkpoint, f)\n            f.flush()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/checkpoint_beit.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\nimport io\nimport os\nimport os.path as osp\nimport pkgutil\nimport time\nimport warnings\nfrom collections import OrderedDict\nfrom importlib import import_module\nfrom tempfile import TemporaryDirectory\n\nimport torch\nimport torchvision\nfrom torch.optim import Optimizer\nfrom torch.utils import model_zoo\nfrom torch.nn import functional as F\n\nimport mmcv\nfrom mmcv.fileio import FileClient\nfrom mmcv.fileio import load as load_file\nfrom mmcv.parallel import is_module_wrapper\nfrom mmcv.utils import mkdir_or_exist\nfrom mmcv.runner import get_dist_info\n\nfrom scipy import interpolate\nimport numpy as np\nimport math\n\nENV_MMCV_HOME = 'MMCV_HOME'\nENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'\nDEFAULT_CACHE_DIR = '~/.cache'\n\n\ndef _get_mmcv_home():\n    mmcv_home = os.path.expanduser(\n        os.getenv(\n            ENV_MMCV_HOME,\n            os.path.join(\n                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))\n\n    mkdir_or_exist(mmcv_home)\n    return mmcv_home\n\n\ndef load_state_dict(module, state_dict, strict=False, logger=None):\n    \"\"\"Load state_dict to a module.\n\n    This method is modified from :meth:`torch.nn.Module.load_state_dict`.\n    Default value for ``strict`` is set to ``False`` and the message for\n    param mismatch will be shown even if strict is False.\n\n    Args:\n        module (Module): Module that receives the state_dict.\n        state_dict (OrderedDict): Weights.\n        strict (bool): whether to strictly enforce that the keys\n            in :attr:`state_dict` match the keys returned by this module's\n            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.\n        logger (:obj:`logging.Logger`, optional): Logger to log the error\n            message. 
If not specified, print function will be used.\n    \"\"\"\n    unexpected_keys = []\n    all_missing_keys = []\n    err_msg = []\n\n    metadata = getattr(state_dict, '_metadata', None)\n    state_dict = state_dict.copy()\n    if metadata is not None:\n        state_dict._metadata = metadata\n\n    # use _load_from_state_dict to enable checkpoint version control\n    def load(module, prefix=''):\n        # recursively check parallel module in case that the model has a\n        # complicated structure, e.g., nn.Module(nn.Module(DDP))\n        if is_module_wrapper(module):\n            module = module.module\n        local_metadata = {} if metadata is None else metadata.get(\n            prefix[:-1], {})\n        module._load_from_state_dict(state_dict, prefix, local_metadata, True,\n                                     all_missing_keys, unexpected_keys,\n                                     err_msg)\n        for name, child in module._modules.items():\n            if child is not None:\n                load(child, prefix + name + '.')\n\n    load(module)\n    load = None  # break load->load reference cycle\n\n    # ignore \"num_batches_tracked\" of BN layers\n    missing_keys = [\n        key for key in all_missing_keys if 'num_batches_tracked' not in key\n    ]\n\n    if unexpected_keys:\n        err_msg.append('unexpected key in source '\n                       f'state_dict: {\", \".join(unexpected_keys)}\\n')\n    if missing_keys:\n        err_msg.append(\n            f'missing keys in source state_dict: {\", \".join(missing_keys)}\\n')\n\n    rank, _ = get_dist_info()\n    if len(err_msg) > 0 and rank == 0:\n        err_msg.insert(\n            0, 'The model and loaded state dict do not match exactly\\n')\n        err_msg = '\\n'.join(err_msg)\n        if strict:\n            raise RuntimeError(err_msg)\n        elif logger is not None:\n            logger.warning(err_msg)\n        else:\n            print(err_msg)\n\n\ndef load_url_dist(url, model_dir=None, map_location=\"cpu\"):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    if rank == 0:\n        checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)\n    return checkpoint\n\n\ndef load_pavimodel_dist(model_path, map_location=None):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    try:\n        from pavi import modelcloud\n    except ImportError:\n        raise ImportError(\n            'Please install pavi to load checkpoint from modelcloud.')\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    if rank == 0:\n        model = modelcloud.get(model_path)\n        with TemporaryDirectory() as tmp_dir:\n            downloaded_file = osp.join(tmp_dir, model.name)\n            model.download(downloaded_file)\n            checkpoint = torch.load(downloaded_file, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            model = modelcloud.get(model_path)\n            with TemporaryDirectory() as tmp_dir:\n                downloaded_file = osp.join(tmp_dir, model.name)\n                model.download(downloaded_file)\n   
             checkpoint = torch.load(\n                    downloaded_file, map_location=map_location)\n    return checkpoint\n\n\ndef load_fileclient_dist(filename, backend, map_location):\n    \"\"\"In distributed setting, this function only download checkpoint at local\n    rank 0.\"\"\"\n    rank, world_size = get_dist_info()\n    rank = int(os.environ.get('LOCAL_RANK', rank))\n    allowed_backends = ['ceph']\n    if backend not in allowed_backends:\n        raise ValueError(f'Load from Backend {backend} is not supported.')\n    if rank == 0:\n        fileclient = FileClient(backend=backend)\n        buffer = io.BytesIO(fileclient.get(filename))\n        checkpoint = torch.load(buffer, map_location=map_location)\n    if world_size > 1:\n        torch.distributed.barrier()\n        if rank > 0:\n            fileclient = FileClient(backend=backend)\n            buffer = io.BytesIO(fileclient.get(filename))\n            checkpoint = torch.load(buffer, map_location=map_location)\n    return checkpoint\n\n\ndef get_torchvision_models():\n    model_urls = dict()\n    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):\n        if ispkg:\n            continue\n        _zoo = import_module(f'torchvision.models.{name}')\n        if hasattr(_zoo, 'model_urls'):\n            _urls = getattr(_zoo, 'model_urls')\n            model_urls.update(_urls)\n    return model_urls\n\n\ndef get_external_models():\n    mmcv_home = _get_mmcv_home()\n    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')\n    default_urls = load_file(default_json_path)\n    assert isinstance(default_urls, dict)\n    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')\n    if osp.exists(external_json_path):\n        external_urls = load_file(external_json_path)\n        assert isinstance(external_urls, dict)\n        default_urls.update(external_urls)\n\n    return default_urls\n\n\ndef get_mmcls_models():\n    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')\n    mmcls_urls = load_file(mmcls_json_path)\n\n    return mmcls_urls\n\n\ndef get_deprecated_model_names():\n    deprecate_json_path = osp.join(mmcv.__path__[0],\n                                   'model_zoo/deprecated.json')\n    deprecate_urls = load_file(deprecate_json_path)\n    assert isinstance(deprecate_urls, dict)\n\n    return deprecate_urls\n\n\ndef _process_mmcls_checkpoint(checkpoint):\n    state_dict = checkpoint['state_dict']\n    new_state_dict = OrderedDict()\n    for k, v in state_dict.items():\n        if k.startswith('backbone.'):\n            new_state_dict[k[9:]] = v\n    new_checkpoint = dict(state_dict=new_state_dict)\n\n    return new_checkpoint\n\n\ndef _load_checkpoint(filename, map_location=None):\n    \"\"\"Load checkpoint from somewhere (modelzoo, file, url).\n\n    Args:\n        filename (str): Accept local filepath, URL, ``torchvision://xxx``,\n            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for\n            details.\n        map_location (str | None): Same as :func:`torch.load`. Default: None.\n\n    Returns:\n        dict | OrderedDict: The loaded checkpoint. 
It can be either an\n            OrderedDict storing model weights or a dict containing other\n            information, which depends on the checkpoint.\n    \"\"\"\n    if filename.startswith('modelzoo://'):\n        warnings.warn('The URL scheme of \"modelzoo://\" is deprecated, please '\n                      'use \"torchvision://\" instead')\n        model_urls = get_torchvision_models()\n        model_name = filename[11:]\n        checkpoint = load_url_dist(model_urls[model_name])\n    elif filename.startswith('torchvision://'):\n        model_urls = get_torchvision_models()\n        model_name = filename[14:]\n        checkpoint = load_url_dist(model_urls[model_name])\n    elif filename.startswith('open-mmlab://'):\n        model_urls = get_external_models()\n        model_name = filename[13:]\n        deprecated_urls = get_deprecated_model_names()\n        if model_name in deprecated_urls:\n            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '\n                          f'of open-mmlab://{deprecated_urls[model_name]}')\n            model_name = deprecated_urls[model_name]\n        model_url = model_urls[model_name]\n        # check if is url\n        if model_url.startswith(('http://', 'https://')):\n            checkpoint = load_url_dist(model_url)\n        else:\n            filename = osp.join(_get_mmcv_home(), model_url)\n            if not osp.isfile(filename):\n                raise IOError(f'{filename} is not a checkpoint file')\n            checkpoint = torch.load(filename, map_location=map_location)\n    elif filename.startswith('mmcls://'):\n        model_urls = get_mmcls_models()\n        model_name = filename[8:]\n        checkpoint = load_url_dist(model_urls[model_name])\n        checkpoint = _process_mmcls_checkpoint(checkpoint)\n    elif filename.startswith(('http://', 'https://')):\n        checkpoint = load_url_dist(filename)\n    elif filename.startswith('pavi://'):\n        model_path = filename[7:]\n        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)\n    elif filename.startswith('s3://'):\n        checkpoint = load_fileclient_dist(\n            filename, backend='ceph', map_location=map_location)\n    else:\n        if not osp.isfile(filename):\n            raise IOError(f'{filename} is not a checkpoint file')\n        checkpoint = torch.load(filename, map_location=map_location)\n    return checkpoint\n\n\ndef cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,\n                     start_warmup_value=0, warmup_steps=-1):\n    warmup_schedule = np.array([])\n    warmup_iters = warmup_epochs * niter_per_ep\n    if warmup_steps > 0:\n        warmup_iters = warmup_steps\n    print(\"Set warmup steps = %d\" % warmup_iters)\n    if warmup_epochs > 0:\n        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)\n\n    iters = np.arange(epochs * niter_per_ep - warmup_iters)\n    schedule = np.array(\n        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])\n\n    schedule = np.concatenate((warmup_schedule, schedule))\n\n    assert len(schedule) == epochs * niter_per_ep\n    return schedule\n\n\ndef load_checkpoint(model,\n                    filename,\n                    map_location='cpu',\n                    strict=False,\n                    logger=None):\n    \"\"\"Load checkpoint from a file or URI.\n\n    Args:\n        model (Module): Module to load checkpoint.\n        filename (str): Accept 
local filepath, URL, ``torchvision://xxx``,\n            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for\n            details.\n        map_location (str): Same as :func:`torch.load`.\n        strict (bool): Whether to allow different params for the model and\n            checkpoint.\n        logger (:mod:`logging.Logger` or None): The logger for error message.\n\n    Returns:\n        dict or OrderedDict: The loaded checkpoint.\n    \"\"\"\n    checkpoint = _load_checkpoint(filename, map_location)\n    # OrderedDict is a subclass of dict\n    if not isinstance(checkpoint, dict):\n        raise RuntimeError(\n            f'No state_dict found in checkpoint file {filename}')\n    # get state_dict from checkpoint\n    if 'state_dict' in checkpoint:\n        state_dict = checkpoint['state_dict']\n    elif 'model' in checkpoint:\n        state_dict = checkpoint['model']\n    elif 'module' in checkpoint:\n        state_dict = checkpoint['module']\n    else:\n        state_dict = checkpoint\n    # strip prefix of state_dict\n    if list(state_dict.keys())[0].startswith('module.'):\n        state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n    # for MoBY, load model of online branch\n    if sorted(list(state_dict.keys()))[0].startswith('encoder'):\n        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}\n\n    # reshape absolute position embedding for Swin\n    if state_dict.get('absolute_pos_embed') is not None:\n        absolute_pos_embed = state_dict['absolute_pos_embed']\n        N1, L, C1 = absolute_pos_embed.size()\n        N2, C2, H, W = model.absolute_pos_embed.size()\n        if N1 != N2 or C1 != C2 or L != H*W:\n            logger.warning(\"Error in loading absolute_pos_embed, pass\")\n        else:\n            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)\n\n    rank, _ = get_dist_info()\n    if \"rel_pos_bias.relative_position_bias_table\" in state_dict:\n        if rank == 0:\n            print(\"Expand the shared relative position embedding to each layers. 
\")\n            num_layers = model.get_num_layers()\n            rel_pos_bias = state_dict[\"rel_pos_bias.relative_position_bias_table\"]\n            for i in range(num_layers):\n                state_dict[\"blocks.%d.attn.relative_position_bias_table\" % i] = rel_pos_bias.clone()\n\n        state_dict.pop(\"rel_pos_bias.relative_position_bias_table\")\n\n    all_keys = list(state_dict.keys())\n    all_keys = sorted(all_keys)\n    print(\"origin keys:\", len(all_keys), all_keys)\n\n    if all_keys[-2].startswith('encoder_to_decoder'):\n        # NOTE: remove all decoder keys\n        all_keys = [key for key in all_keys if key.startswith('encoder.')]\n        print(\"all keys:\", all_keys)\n        for key in all_keys:\n            new_key = key.replace('encoder.','')\n            # print(\"new_key:\", new_key)\n            state_dict[new_key] = state_dict[key]\n            state_dict.pop(key)\n            \n        for key in list(state_dict.keys()):\n            if key.startswith('decoder.'):\n                # print(\"key:\", key)\n                state_dict.pop(key)\n\n        # NOTE: replace norm with fc_norm\n        for key in list(state_dict.keys()):\n            # print(\"new key:\", key)\n            if key.startswith('norm.'):\n                new_key = key.replace('norm.','fc_norm.')\n                state_dict[new_key] = state_dict[key]\n                state_dict.pop(key)\n    \n    print(\"new keys:\", len(state_dict), state_dict.keys())\n    \n\n    for key in all_keys:\n        if \"relative_position_index\" in key:\n            state_dict.pop(key)\n\n        if \"relative_position_bias_table\" in key:\n            rel_pos_bias = state_dict[key]\n            src_num_pos, num_attn_heads = rel_pos_bias.size()\n            dst_num_pos, _ = model.state_dict()[key].size()\n            dst_patch_shape = model.patch_embed.patch_shape\n            if dst_patch_shape[0] != dst_patch_shape[1]:\n                raise NotImplementedError()\n            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)\n            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)\n            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)\n            if src_size != dst_size:\n                if rank == 0:\n                    print(\"Position interpolate for %s from %dx%d to %dx%d\" % (\n                        key, src_size, src_size, dst_size, dst_size))\n                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\n                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\n\n                def geometric_progression(a, r, n):\n                    return a * (1.0 - r ** n) / (1.0 - r)\n\n                left, right = 1.01, 1.5\n                while right - left > 1e-6:\n                    q = (left + right) / 2.0\n                    gp = geometric_progression(1, q, src_size // 2)\n                    if gp > dst_size // 2:\n                        right = q\n                    else:\n                        left = q\n\n                # if q > 1.13492:\n                #     q = 1.13492\n\n                dis = []\n                cur = 1\n                for i in range(src_size // 2):\n                    dis.append(cur)\n                    cur += q ** (i + 1)\n\n                r_ids = [-_ for _ in reversed(dis)]\n\n                x = r_ids + [0] + dis\n                y = r_ids + [0] + dis\n\n                t = dst_size // 2.0\n                dx = np.arange(-t, t + 0.1, 1.0)\n                dy = np.arange(-t, t + 
0.1, 1.0)\n                if rank == 0:\n                    print(\"x = {}\".format(x))\n                    print(\"dx = {}\".format(dx))\n\n                all_rel_pos_bias = []\n\n                for i in range(num_attn_heads):\n                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()\n                    f = interpolate.interp2d(x, y, z, kind='cubic')\n                    all_rel_pos_bias.append(\n                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))\n\n                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)\n                state_dict[key] = new_rel_pos_bias\n\n    if 'pos_embed' in state_dict:\n        pos_embed_checkpoint = state_dict['pos_embed']\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        new_size = int(num_patches ** 0.5)\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size:\n            if rank == 0:\n                print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            state_dict['pos_embed'] = new_pos_embed\n\n    # interpolate position bias table if needed\n    relative_position_bias_table_keys = [k for k in state_dict.keys() if \"relative_position_bias_table\" in k]\n    for table_key in relative_position_bias_table_keys:\n        table_pretrained = state_dict[table_key]\n        table_current = model.state_dict()[table_key]\n        L1, nH1 = table_pretrained.size()\n        L2, nH2 = table_current.size()\n        if nH1 != nH2:\n            logger.warning(f\"Error in loading {table_key}, pass\")\n        else:\n            if L1 != L2:\n                S1 = int(L1 ** 0.5)\n                S2 = int(L2 ** 0.5)\n                table_pretrained_resized = F.interpolate(\n                     table_pretrained.permute(1, 0).view(1, nH1, S1, S1),\n                     size=(S2, S2), mode='bicubic')\n                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)\n\n    # load state_dict\n    load_state_dict(model, state_dict, strict, logger)\n    return checkpoint\n\n\ndef weights_to_cpu(state_dict):\n    \"\"\"Copy a model state_dict to cpu.\n\n    Args:\n        state_dict (OrderedDict): Model weights on GPU.\n\n    Returns:\n        OrderedDict: Model weights on GPU.\n    \"\"\"\n    state_dict_cpu = OrderedDict()\n    for key, val in state_dict.items():\n        state_dict_cpu[key] = val.cpu()\n    return state_dict_cpu\n\n\ndef 
_save_to_state_dict(module, destination, prefix, keep_vars):\n    \"\"\"Saves module state to `destination` dictionary.\n\n    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.\n\n    Args:\n        module (nn.Module): The module to generate state_dict.\n        destination (dict): A dict where state will be stored.\n        prefix (str): The prefix for parameters and buffers used in this\n            module.\n    \"\"\"\n    for name, param in module._parameters.items():\n        if param is not None:\n            destination[prefix + name] = param if keep_vars else param.detach()\n    for name, buf in module._buffers.items():\n        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d\n        if buf is not None:\n            destination[prefix + name] = buf if keep_vars else buf.detach()\n\n\ndef get_state_dict(module, destination=None, prefix='', keep_vars=False):\n    \"\"\"Returns a dictionary containing a whole state of the module.\n\n    Both parameters and persistent buffers (e.g. running averages) are\n    included. Keys are corresponding parameter and buffer names.\n\n    This method is modified from :meth:`torch.nn.Module.state_dict` to\n    recursively check parallel module in case that the model has a complicated\n    structure, e.g., nn.Module(nn.Module(DDP)).\n\n    Args:\n        module (nn.Module): The module to generate state_dict.\n        destination (OrderedDict): Returned dict for the state of the\n            module.\n        prefix (str): Prefix of the key.\n        keep_vars (bool): Whether to keep the variable property of the\n            parameters. Default: False.\n\n    Returns:\n        dict: A dictionary containing a whole state of the module.\n    \"\"\"\n    # recursively check parallel module in case that the model has a\n    # complicated structure, e.g., nn.Module(nn.Module(DDP))\n    if is_module_wrapper(module):\n        module = module.module\n\n    # below is the same as torch.nn.Module.state_dict()\n    if destination is None:\n        destination = OrderedDict()\n        destination._metadata = OrderedDict()\n    destination._metadata[prefix[:-1]] = local_metadata = dict(\n        version=module._version)\n    _save_to_state_dict(module, destination, prefix, keep_vars)\n    for name, child in module._modules.items():\n        if child is not None:\n            get_state_dict(\n                child, destination, prefix + name + '.', keep_vars=keep_vars)\n    for hook in module._state_dict_hooks.values():\n        hook_result = hook(module, destination, prefix, local_metadata)\n        if hook_result is not None:\n            destination = hook_result\n    return destination\n\n\ndef save_checkpoint(model, filename, optimizer=None, meta=None):\n    \"\"\"Save checkpoint to file.\n\n    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and\n    ``optimizer``. 
By default ``meta`` will contain version and time info.\n\n    Args:\n        model (Module): Module whose params are to be saved.\n        filename (str): Checkpoint filename.\n        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.\n        meta (dict, optional): Metadata to be saved in checkpoint.\n    \"\"\"\n    if meta is None:\n        meta = {}\n    elif not isinstance(meta, dict):\n        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')\n    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())\n\n    if is_module_wrapper(model):\n        model = model.module\n\n    if hasattr(model, 'CLASSES') and model.CLASSES is not None:\n        # save class name to the meta\n        meta.update(CLASSES=model.CLASSES)\n\n    checkpoint = {\n        'meta': meta,\n        'state_dict': weights_to_cpu(get_state_dict(model))\n    }\n    # save optimizer state dict in the checkpoint\n    if isinstance(optimizer, Optimizer):\n        checkpoint['optimizer'] = optimizer.state_dict()\n    elif isinstance(optimizer, dict):\n        checkpoint['optimizer'] = {}\n        for name, optim in optimizer.items():\n            checkpoint['optimizer'][name] = optim.state_dict()\n\n    if filename.startswith('pavi://'):\n        try:\n            from pavi import modelcloud\n            from pavi.exception import NodeNotFoundError\n        except ImportError:\n            raise ImportError(\n                'Please install pavi to load checkpoint from modelcloud.')\n        model_path = filename[7:]\n        root = modelcloud.Folder()\n        model_dir, model_name = osp.split(model_path)\n        try:\n            model = modelcloud.get(model_dir)\n        except NodeNotFoundError:\n            model = root.create_training_model(model_dir)\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint_file = osp.join(tmp_dir, model_name)\n            with open(checkpoint_file, 'wb') as f:\n                torch.save(checkpoint, f)\n                f.flush()\n            model.create_file(checkpoint_file, name=model_name)\n    else:\n        mmcv.mkdir_or_exist(osp.dirname(filename))\n        # immediately flush buffer\n        with open(filename, 'wb') as f:\n            torch.save(checkpoint, f)\n            f.flush()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py",
    "content": "import json\nfrom mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor\nfrom mmcv.runner import get_dist_info\n\n\ndef get_num_layer_for_vit(var_name, num_max_layer):\n    if var_name in (\"backbone.cls_token\", \"backbone.mask_token\", \"backbone.pos_embed\"):\n        return 0\n    elif var_name.startswith(\"backbone.patch_embed\"):\n        return 0\n    elif var_name.startswith(\"backbone.blocks\"):\n        layer_id = int(var_name.split('.')[2])\n        return layer_id + 1\n    else:\n        return num_max_layer - 1\n\n\n@OPTIMIZER_BUILDERS.register_module()\nclass LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):\n    def add_params(self, params, module, prefix='', is_dcn_module=None):\n        \"\"\"Add all parameters of module to the params list.\n        The parameters of the given module will be added to the list of param\n        groups, with specific rules defined by paramwise_cfg.\n        Args:\n            params (list[dict]): A list of param groups, it will be modified\n                in place.\n            module (nn.Module): The module to be added.\n            prefix (str): The prefix of the module\n            is_dcn_module (int|float|None): If the current module is a\n                submodule of DCN, `is_dcn_module` will be passed to\n                control conv_offset layer's learning rate. Defaults to None.\n        \"\"\"\n        parameter_groups = {}\n        print(self.paramwise_cfg)\n        num_layers = self.paramwise_cfg.get('num_layers') + 2\n        layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')\n        print(\"Build LayerDecayOptimizerConstructor %f - %d\" % (layer_decay_rate, num_layers))\n        weight_decay = self.base_wd\n\n        for name, param in module.named_parameters():\n            if not param.requires_grad:\n                continue  # frozen weights\n            if len(param.shape) == 1 or name.endswith(\".bias\") or name in ('pos_embed', 'cls_token'):\n                group_name = \"no_decay\"\n                this_weight_decay = 0.\n            else:\n                group_name = \"decay\"\n                this_weight_decay = weight_decay\n\n            layer_id = get_num_layer_for_vit(name, num_layers)\n            group_name = \"layer_%d_%s\" % (layer_id, group_name)\n\n            if group_name not in parameter_groups:\n                scale = layer_decay_rate ** (num_layers - layer_id - 1)\n\n                parameter_groups[group_name] = {\n                    \"weight_decay\": this_weight_decay,\n                    \"params\": [],\n                    \"param_names\": [], \n                    \"lr_scale\": scale, \n                    \"group_name\": group_name, \n                    \"lr\": scale * self.base_lr, \n                }\n\n            parameter_groups[group_name][\"params\"].append(param)\n            parameter_groups[group_name][\"param_names\"].append(name)\n        rank, _ = get_dist_info()\n        if rank == 0:\n            to_display = {}\n            for key in parameter_groups:\n                to_display[key] = {\n                    \"param_names\": parameter_groups[key][\"param_names\"], \n                    \"lr_scale\": parameter_groups[key][\"lr_scale\"], \n                    \"lr\": parameter_groups[key][\"lr\"], \n                    \"weight_decay\": parameter_groups[key][\"weight_decay\"], \n                }\n            print(\"Param groups = %s\" % json.dumps(to_display, indent=2))\n        \n        # state_dict = 
module.state_dict()\n        # for group_name in parameter_groups:\n        #     group = parameter_groups[group_name]\n        #     for name in group[\"param_names\"]:\n        #         group[\"params\"].append(state_dict[name])\n        params.extend(parameter_groups.values())\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/resize_transform.py",
    "content": "import mmcv\nimport numpy as np\n\nfrom mmseg.datasets.builder import PIPELINES\n\n\n@PIPELINES.register_module()\nclass SETR_Resize(object):\n    \"\"\"Resize images & seg.\n\n    This transform resizes the input image to some scale. If the input dict\n    contains the key \"scale\", then the scale in the input dict is used,\n    otherwise the specified scale in the init method is used.\n\n    ``img_scale`` can either be a tuple (single-scale) or a list of tuple\n    (multi-scale). There are 3 multiscale modes:\n\n    - ``ratio_range is not None``: randomly sample a ratio from the ratio range\n    and multiply it with the image scale.\n\n    - ``ratio_range is None and multiscale_mode == \"range\"``: randomly sample a\n    scale from the a range.\n\n    - ``ratio_range is None and multiscale_mode == \"value\"``: randomly sample a\n    scale from multiple scales.\n\n    Args:\n        img_scale (tuple or list[tuple]): Images scales for resizing.\n        multiscale_mode (str): Either \"range\" or \"value\".\n        ratio_range (tuple[float]): (min_ratio, max_ratio)\n        keep_ratio (bool): Whether to keep the aspect ratio when resizing the\n            image.\n    \"\"\"\n\n    def __init__(self,\n                 img_scale=None,\n                 multiscale_mode='range',\n                 ratio_range=None,\n                 keep_ratio=True,\n                 crop_size=None,\n                 setr_multi_scale=False):\n\n        if img_scale is None:\n            self.img_scale = None\n        else:\n            if isinstance(img_scale, list):\n                self.img_scale = img_scale\n            else:\n                self.img_scale = [img_scale]\n            # assert mmcv.is_list_of(self.img_scale, tuple)\n\n        if ratio_range is not None:\n            # mode 1: given a scale and a range of image ratio\n            assert len(self.img_scale) == 1\n        else:\n            # mode 2: given multiple scales or a range of scales\n            assert multiscale_mode in ['value', 'range']\n\n        self.multiscale_mode = multiscale_mode\n        self.ratio_range = ratio_range\n        self.keep_ratio = keep_ratio\n        self.crop_size = crop_size\n        self.setr_multi_scale = setr_multi_scale\n\n    @staticmethod\n    def random_select(img_scales):\n        \"\"\"Randomly select an img_scale from given candidates.\n\n        Args:\n            img_scales (list[tuple]): Images scales for selection.\n\n        Returns:\n            (tuple, int): Returns a tuple ``(img_scale, scale_dix)``,\n                where ``img_scale`` is the selected image scale and\n                ``scale_idx`` is the selected index in the given candidates.\n        \"\"\"\n\n        assert mmcv.is_list_of(img_scales, tuple)\n        scale_idx = np.random.randint(len(img_scales))\n        img_scale = img_scales[scale_idx]\n        return img_scale, scale_idx\n\n    @staticmethod\n    def random_sample(img_scales):\n        \"\"\"Randomly sample an img_scale when ``multiscale_mode=='range'``.\n\n        Args:\n            img_scales (list[tuple]): Images scale range for sampling.\n                There must be two tuples in img_scales, which specify the lower\n                and uper bound of image scales.\n\n        Returns:\n            (tuple, None): Returns a tuple ``(img_scale, None)``, where\n                ``img_scale`` is sampled scale and None is just a placeholder\n                to be consistent with :func:`random_select`.\n        \"\"\"\n\n        assert 
mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2\n        img_scale_long = [max(s) for s in img_scales]\n        img_scale_short = [min(s) for s in img_scales]\n        long_edge = np.random.randint(\n            min(img_scale_long),\n            max(img_scale_long) + 1)\n        short_edge = np.random.randint(\n            min(img_scale_short),\n            max(img_scale_short) + 1)\n        img_scale = (long_edge, short_edge)\n        return img_scale, None\n\n    @staticmethod\n    def random_sample_ratio(img_scale, ratio_range):\n        \"\"\"Randomly sample an img_scale when ``ratio_range`` is specified.\n\n        A ratio will be randomly sampled from the range specified by\n        ``ratio_range``. Then it would be multiplied with ``img_scale`` to\n        generate sampled scale.\n\n        Args:\n            img_scale (tuple): Images scale base to multiply with ratio.\n            ratio_range (tuple[float]): The minimum and maximum ratio to scale\n                the ``img_scale``.\n\n        Returns:\n            (tuple, None): Returns a tuple ``(scale, None)``, where\n                ``scale`` is sampled ratio multiplied with ``img_scale`` and\n                None is just a placeholder to be consistent with\n                :func:`random_select`.\n        \"\"\"\n\n        assert isinstance(img_scale, tuple) and len(img_scale) == 2\n        min_ratio, max_ratio = ratio_range\n        assert min_ratio <= max_ratio\n        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio\n        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)\n        return scale, None\n\n    def _random_scale(self, results):\n        \"\"\"Randomly sample an img_scale according to ``ratio_range`` and\n        ``multiscale_mode``.\n\n        If ``ratio_range`` is specified, a ratio will be sampled and be\n        multiplied with ``img_scale``.\n        If multiple scales are specified by ``img_scale``, a scale will be\n        sampled according to ``multiscale_mode``.\n        Otherwise, single scale will be used.\n\n        Args:\n            results (dict): Result dict from :obj:`dataset`.\n\n        Returns:\n            dict: Two new keys 'scale` and 'scale_idx` are added into\n                ``results``, which would be used by subsequent pipelines.\n        \"\"\"\n\n        if self.ratio_range is not None:\n            scale, scale_idx = self.random_sample_ratio(\n                self.img_scale[0], self.ratio_range)\n        elif len(self.img_scale) == 1:\n            scale, scale_idx = self.img_scale[0], 0\n        elif self.multiscale_mode == 'range':\n            scale, scale_idx = self.random_sample(self.img_scale)\n        elif self.multiscale_mode == 'value':\n            scale, scale_idx = self.random_select(self.img_scale)\n        else:\n            raise NotImplementedError\n\n        results['scale'] = scale\n        results['scale_idx'] = scale_idx\n\n    def _resize_img(self, results):\n        \"\"\"Resize images with ``results['scale']``.\"\"\"\n\n        if self.keep_ratio:\n            if self.setr_multi_scale:\n                if min(results['scale']) < self.crop_size[0]:\n                    new_short = self.crop_size[0]\n                else:\n                    new_short = min(results['scale'])\n                    \n                h, w = results['img'].shape[:2]\n                if h > w:\n                    new_h, new_w = new_short * h / w, new_short\n                else:\n                    new_h, new_w = new_short, new_short * 
w / h\n                results['scale'] = (new_h, new_w)\n\n            img, scale_factor = mmcv.imrescale(\n                results['img'], results['scale'], return_scale=True)\n            # the w_scale and h_scale has minor difference\n            # a real fix should be done in the mmcv.imrescale in the future\n            new_h, new_w = img.shape[:2]\n            h, w = results['img'].shape[:2]\n            w_scale = new_w / w\n            h_scale = new_h / h\n        else:\n            img, w_scale, h_scale = mmcv.imresize(\n                results['img'], results['scale'], return_scale=True)\n        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],\n                                dtype=np.float32)\n        results['img'] = img\n        results['img_shape'] = img.shape\n        results['pad_shape'] = img.shape  # in case that there is no padding\n        results['scale_factor'] = scale_factor\n        results['keep_ratio'] = self.keep_ratio\n\n    def _resize_seg(self, results):\n        \"\"\"Resize semantic segmentation map with ``results['scale']``.\"\"\"\n        for key in results.get('seg_fields', []):\n            if self.keep_ratio:\n                gt_seg = mmcv.imrescale(\n                    results[key], results['scale'], interpolation='nearest')\n            else:\n                gt_seg = mmcv.imresize(\n                    results[key], results['scale'], interpolation='nearest')\n            results['gt_semantic_seg'] = gt_seg\n\n    def __call__(self, results):\n        \"\"\"Call function to resize images, bounding boxes, masks, semantic\n        segmentation map.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',\n                'keep_ratio' keys are added into result dict.\n        \"\"\"\n\n        if 'scale' not in results:\n            self._random_scale(results)\n        self._resize_img(results)\n        self._resize_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += (f'(img_scale={self.img_scale}, '\n                     f'multiscale_mode={self.multiscale_mode}, '\n                     f'ratio_range={self.ratio_range}, '\n                     f'keep_ratio={self.keep_ratio})')\n        return repr_str\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmcv_custom/train_api.py",
    "content": "import random\nimport warnings\n\nimport numpy as np\nimport torch\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import build_optimizer, build_runner\n\nfrom mmseg.core import DistEvalHook, EvalHook\nfrom mmseg.datasets import build_dataloader, build_dataset\nfrom mmseg.utils import get_root_logger\ntry:\n    import apex\nexcept:\n    print('apex is not installed')\n\n\ndef set_random_seed(seed, deterministic=False):\n    \"\"\"Set random seed.\n\n    Args:\n        seed (int): Seed to be used.\n        deterministic (bool): Whether to set the deterministic option for\n            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`\n            to True and `torch.backends.cudnn.benchmark` to False.\n            Default: False.\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    if deterministic:\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n\n\ndef train_segmentor(model,\n                    dataset,\n                    cfg,\n                    distributed=False,\n                    validate=False,\n                    timestamp=None,\n                    meta=None):\n    \"\"\"Launch segmentor training.\"\"\"\n    logger = get_root_logger(cfg.log_level)\n\n    # prepare data loaders\n    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]\n    data_loaders = [\n        build_dataloader(\n            ds,\n            cfg.data.samples_per_gpu,\n            cfg.data.workers_per_gpu,\n            # cfg.gpus will be ignored if distributed\n            len(cfg.gpu_ids),\n            dist=distributed,\n            seed=cfg.seed,\n            drop_last=True) for ds in dataset\n    ]\n\n    # build optimizer\n    optimizer = build_optimizer(model, cfg.optimizer)\n\n    # use apex fp16 optimizer\n    if cfg.optimizer_config.get(\"type\", None) and cfg.optimizer_config[\"type\"] == \"DistOptimizerHook\":\n        if cfg.optimizer_config.get(\"use_fp16\", False):\n            model, optimizer = apex.amp.initialize(\n                model.cuda(), optimizer, opt_level=\"O1\")\n            for m in model.modules():\n                if hasattr(m, \"fp16_enabled\"):\n                    m.fp16_enabled = True\n\n    # put model on gpus\n    if distributed:\n        find_unused_parameters = cfg.get('find_unused_parameters', False)\n        # Sets the `find_unused_parameters` parameter in\n        # torch.nn.parallel.DistributedDataParallel\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False,\n            find_unused_parameters=find_unused_parameters)\n    else:\n        model = MMDataParallel(\n            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)\n\n    if cfg.get('runner') is None:\n        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}\n        warnings.warn(\n            'config is now expected to have a `runner` section, '\n            'please set `runner` in your config.', UserWarning)\n\n    runner = build_runner(\n        cfg.runner,\n        default_args=dict(\n            model=model,\n            batch_processor=None,\n            optimizer=optimizer,\n            work_dir=cfg.work_dir,\n            logger=logger,\n            meta=meta))\n\n    # register hooks\n    runner.register_training_hooks(cfg.lr_config, 
cfg.optimizer_config,\n                                   cfg.checkpoint_config, cfg.log_config,\n                                   cfg.get('momentum_config', None))\n\n    # an ugly workaround to make the .log and .log.json filenames the same\n    runner.timestamp = timestamp\n\n    # register eval hooks\n    if validate:\n        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))\n        val_dataloader = build_dataloader(\n            val_dataset,\n            samples_per_gpu=1,\n            workers_per_gpu=cfg.data.workers_per_gpu,\n            dist=distributed,\n            shuffle=False)\n        eval_cfg = cfg.get('evaluation', {})\n        eval_cfg['by_epoch'] = 'IterBasedRunner' not in cfg.runner['type']\n        eval_hook = DistEvalHook if distributed else EvalHook\n        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))\n\n    if cfg.resume_from:\n        runner.resume(cfg.resume_from)\n    elif cfg.load_from:\n        runner.load_checkpoint(cfg.load_from)\n    runner.run(data_loaders, cfg.workflow)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/__init__.py",
    "content": "import mmcv\n\nfrom .version import __version__, version_info\n\nMMCV_MIN = '1.1.4'\nMMCV_MAX = '1.3.0'\n\n\ndef digit_version(version_str):\n    digit_version = []\n    for x in version_str.split('.'):\n        if x.isdigit():\n            digit_version.append(int(x))\n        elif x.find('rc') != -1:\n            patch_version = x.split('rc')\n            digit_version.append(int(patch_version[0]) - 1)\n            digit_version.append(int(patch_version[1]))\n    return digit_version\n\n\nmmcv_min_version = digit_version(MMCV_MIN)\nmmcv_max_version = digit_version(MMCV_MAX)\nmmcv_version = digit_version(mmcv.__version__)\n\n\nassert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \\\n    f'MMCV=={mmcv.__version__} is used but incompatible. ' \\\n    f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'\n\n__all__ = ['__version__', 'version_info']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/apis/__init__.py",
    "content": "from .inference import inference_segmentor, init_segmentor, show_result_pyplot\nfrom .test import multi_gpu_test, single_gpu_test\nfrom .train import get_root_logger, set_random_seed, train_segmentor\n\n__all__ = [\n    'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',\n    'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',\n    'show_result_pyplot'\n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/apis/inference.py",
    "content": "import matplotlib.pyplot as plt\nimport mmcv\nimport torch\nfrom mmcv.parallel import collate, scatter\nfrom mmcv.runner import load_checkpoint\n\nfrom mmseg.datasets.pipelines import Compose\nfrom mmseg.models import build_segmentor\n\n\ndef init_segmentor(config, checkpoint=None, device='cuda:0'):\n    \"\"\"Initialize a segmentor from config file.\n\n    Args:\n        config (str or :obj:`mmcv.Config`): Config file path or the config\n            object.\n        checkpoint (str, optional): Checkpoint path. If left as None, the model\n            will not load any weights.\n        device (str, optional) CPU/CUDA device option. Default 'cuda:0'.\n            Use 'cpu' for loading model on CPU.\n    Returns:\n        nn.Module: The constructed segmentor.\n    \"\"\"\n    if isinstance(config, str):\n        config = mmcv.Config.fromfile(config)\n    elif not isinstance(config, mmcv.Config):\n        raise TypeError('config must be a filename or Config object, '\n                        'but got {}'.format(type(config)))\n    config.model.pretrained = None\n    config.model.train_cfg = None\n    model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))\n    if checkpoint is not None:\n        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')\n        model.CLASSES = checkpoint['meta']['CLASSES']\n        model.PALETTE = checkpoint['meta']['PALETTE']\n    model.cfg = config  # save the config in the model for convenience\n    model.to(device)\n    model.eval()\n    return model\n\n\nclass LoadImage:\n    \"\"\"A simple pipeline to load image.\"\"\"\n\n    def __call__(self, results):\n        \"\"\"Call function to load images into results.\n\n        Args:\n            results (dict): A result dict contains the file name\n                of the image to be read.\n\n        Returns:\n            dict: ``results`` will be returned containing loaded image.\n        \"\"\"\n\n        if isinstance(results['img'], str):\n            results['filename'] = results['img']\n            results['ori_filename'] = results['img']\n        else:\n            results['filename'] = None\n            results['ori_filename'] = None\n        img = mmcv.imread(results['img'])\n        results['img'] = img\n        results['img_shape'] = img.shape\n        results['ori_shape'] = img.shape\n        return results\n\n\ndef inference_segmentor(model, img):\n    \"\"\"Inference image(s) with the segmentor.\n\n    Args:\n        model (nn.Module): The loaded segmentor.\n        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded\n            images.\n\n    Returns:\n        (list[Tensor]): The segmentation result.\n    \"\"\"\n    cfg = model.cfg\n    device = next(model.parameters()).device  # model device\n    # build the data pipeline\n    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]\n    test_pipeline = Compose(test_pipeline)\n    # prepare data\n    data = dict(img=img)\n    data = test_pipeline(data)\n    data = collate([data], samples_per_gpu=1)\n    if next(model.parameters()).is_cuda:\n        # scatter to specified GPU\n        data = scatter(data, [device])[0]\n    else:\n        data['img_metas'] = [i.data[0] for i in data['img_metas']]\n\n    # forward the model\n    with torch.no_grad():\n        result = model(return_loss=False, rescale=True, **data)\n    return result\n\n\ndef show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)):\n    \"\"\"Visualize the segmentation results on the image.\n\n    
Args:\n        model (nn.Module): The loaded segmentor.\n        img (str or np.ndarray): Image filename or loaded image.\n        result (list): The segmentation result.\n        palette (list[list[int]] | None): The palette of segmentation\n            map. If None is given, random palette will be generated.\n            Default: None\n        fig_size (tuple): Figure size of the pyplot figure.\n    \"\"\"\n    if hasattr(model, 'module'):\n        model = model.module\n    img = model.show_result(img, result, palette=palette, show=False)\n    plt.figure(figsize=fig_size)\n    plt.imshow(mmcv.bgr2rgb(img))\n    plt.show()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/apis/test.py",
    "content": "import os.path as osp\nimport pickle\nimport shutil\nimport tempfile\n\nimport mmcv\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom mmcv.image import tensor2imgs\nfrom mmcv.runner import get_dist_info\n\n\ndef np2tmp(array, temp_file_name=None):\n    \"\"\"Save ndarray to local numpy file.\n\n    Args:\n        array (ndarray): Ndarray to save.\n        temp_file_name (str): Numpy file name. If 'temp_file_name=None', this\n            function will generate a file name with tempfile.NamedTemporaryFile\n            to save ndarray. Default: None.\n\n    Returns:\n        str: The numpy file name.\n    \"\"\"\n\n    if temp_file_name is None:\n        temp_file_name = tempfile.NamedTemporaryFile(\n            suffix='.npy', delete=False).name\n    np.save(temp_file_name, array)\n    return temp_file_name\n\n\ndef single_gpu_test(model,\n                    data_loader,\n                    show=False,\n                    out_dir=None,\n                    efficient_test=False):\n    \"\"\"Test with single GPU.\n\n    Args:\n        model (nn.Module): Model to be tested.\n        data_loader (utils.data.Dataloader): Pytorch data loader.\n        show (bool): Whether show results during infernece. Default: False.\n        out_dir (str, optional): If specified, the results will be dumped into\n            the directory to save output results.\n        efficient_test (bool): Whether save the results as local numpy files to\n            save CPU memory during evaluation. Default: False.\n\n    Returns:\n        list: The prediction results.\n    \"\"\"\n\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    prog_bar = mmcv.ProgressBar(len(dataset))\n    for i, data in enumerate(data_loader):\n        with torch.no_grad():\n            result = model(return_loss=False, **data)\n\n        if show or out_dir:\n            img_tensor = data['img'][0]\n            img_metas = data['img_metas'][0].data[0]\n            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])\n            assert len(imgs) == len(img_metas)\n\n            for img, img_meta in zip(imgs, img_metas):\n                h, w, _ = img_meta['img_shape']\n                img_show = img[:h, :w, :]\n\n                ori_h, ori_w = img_meta['ori_shape'][:-1]\n                img_show = mmcv.imresize(img_show, (ori_w, ori_h))\n\n                if out_dir:\n                    out_file = osp.join(out_dir, img_meta['ori_filename'])\n                else:\n                    out_file = None\n\n                model.module.show_result(\n                    img_show,\n                    result,\n                    palette=dataset.PALETTE,\n                    show=show,\n                    out_file=out_file)\n\n        if isinstance(result, list):\n            if efficient_test:\n                result = [np2tmp(_) for _ in result]\n            results.extend(result)\n        else:\n            if efficient_test:\n                result = np2tmp(result)\n            results.append(result)\n\n        batch_size = data['img'][0].size(0)\n        for _ in range(batch_size):\n            prog_bar.update()\n    return results\n\n\ndef multi_gpu_test(model,\n                   data_loader,\n                   tmpdir=None,\n                   gpu_collect=False,\n                   efficient_test=False):\n    \"\"\"Test model with multiple gpus.\n\n    This method tests model with multiple gpus and collects the results\n    under two different modes: gpu and cpu modes. 
By setting 'gpu_collect=True'\n    it encodes results to gpu tensors and use gpu communication for results\n    collection. On cpu mode it saves the results on different gpus to 'tmpdir'\n    and collects them by the rank 0 worker.\n\n    Args:\n        model (nn.Module): Model to be tested.\n        data_loader (utils.data.Dataloader): Pytorch data loader.\n        tmpdir (str): Path of directory to save the temporary results from\n            different gpus under cpu mode.\n        gpu_collect (bool): Option to use either gpu or cpu to collect results.\n        efficient_test (bool): Whether save the results as local numpy files to\n            save CPU memory during evaluation. Default: False.\n\n    Returns:\n        list: The prediction results.\n    \"\"\"\n\n    model.eval()\n    results = []\n    dataset = data_loader.dataset\n    rank, world_size = get_dist_info()\n    if rank == 0:\n        prog_bar = mmcv.ProgressBar(len(dataset))\n    for i, data in enumerate(data_loader):\n        with torch.no_grad():\n            result = model(return_loss=False, rescale=True, **data)\n\n        if isinstance(result, list):\n            if efficient_test:\n                result = [np2tmp(_) for _ in result]\n            results.extend(result)\n        else:\n            if efficient_test:\n                result = np2tmp(result)\n            results.append(result)\n\n        if rank == 0:\n            batch_size = data['img'][0].size(0)\n            for _ in range(batch_size * world_size):\n                prog_bar.update()\n\n    # collect results from all ranks\n    if gpu_collect:\n        results = collect_results_gpu(results, len(dataset))\n    else:\n        results = collect_results_cpu(results, len(dataset), tmpdir)\n    return results\n\n\ndef collect_results_cpu(result_part, size, tmpdir=None):\n    \"\"\"Collect results with CPU.\"\"\"\n    rank, world_size = get_dist_info()\n    # create a tmp dir if it is not specified\n    if tmpdir is None:\n        MAX_LEN = 512\n        # 32 is whitespace\n        dir_tensor = torch.full((MAX_LEN, ),\n                                32,\n                                dtype=torch.uint8,\n                                device='cuda')\n        if rank == 0:\n            tmpdir = tempfile.mkdtemp()\n            tmpdir = torch.tensor(\n                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')\n            dir_tensor[:len(tmpdir)] = tmpdir\n        dist.broadcast(dir_tensor, 0)\n        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()\n    else:\n        mmcv.mkdir_or_exist(tmpdir)\n    # dump the part result to the dir\n    mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))\n    dist.barrier()\n    # collect all parts\n    if rank != 0:\n        return None\n    else:\n        # load results of all parts from tmp dir\n        part_list = []\n        for i in range(world_size):\n            part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))\n            part_list.append(mmcv.load(part_file))\n        # sort the results\n        ordered_results = []\n        for res in zip(*part_list):\n            ordered_results.extend(list(res))\n        # the dataloader may pad some samples\n        ordered_results = ordered_results[:size]\n        # remove tmp dir\n        shutil.rmtree(tmpdir)\n        return ordered_results\n\n\ndef collect_results_gpu(result_part, size):\n    \"\"\"Collect results with GPU.\"\"\"\n    rank, world_size = get_dist_info()\n    # dump result part to tensor with pickle\n  
  part_tensor = torch.tensor(\n        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')\n    # gather all result part tensor shape\n    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')\n    shape_list = [shape_tensor.clone() for _ in range(world_size)]\n    dist.all_gather(shape_list, shape_tensor)\n    # padding result part tensor to max length\n    shape_max = torch.tensor(shape_list).max()\n    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')\n    part_send[:shape_tensor[0]] = part_tensor\n    part_recv_list = [\n        part_tensor.new_zeros(shape_max) for _ in range(world_size)\n    ]\n    # gather all result part\n    dist.all_gather(part_recv_list, part_send)\n\n    if rank == 0:\n        part_list = []\n        for recv, shape in zip(part_recv_list, shape_list):\n            part_list.append(\n                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))\n        # sort the results\n        ordered_results = []\n        for res in zip(*part_list):\n            ordered_results.extend(list(res))\n        # the dataloader may pad some samples\n        ordered_results = ordered_results[:size]\n        return ordered_results\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/apis/train.py",
    "content": "import random\nimport warnings\n\nimport numpy as np\nimport torch\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import build_optimizer, build_runner\n\nfrom mmseg.core import DistEvalHook, EvalHook\nfrom mmseg.datasets import build_dataloader, build_dataset\nfrom mmseg.utils import get_root_logger\n\n\ndef set_random_seed(seed, deterministic=False):\n    \"\"\"Set random seed.\n\n    Args:\n        seed (int): Seed to be used.\n        deterministic (bool): Whether to set the deterministic option for\n            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`\n            to True and `torch.backends.cudnn.benchmark` to False.\n            Default: False.\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    if deterministic:\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n\n\ndef train_segmentor(model,\n                    dataset,\n                    cfg,\n                    distributed=False,\n                    validate=False,\n                    timestamp=None,\n                    meta=None):\n    \"\"\"Launch segmentor training.\"\"\"\n    logger = get_root_logger(cfg.log_level)\n\n    # prepare data loaders\n    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]\n    data_loaders = [\n        build_dataloader(\n            ds,\n            cfg.data.samples_per_gpu,\n            cfg.data.workers_per_gpu,\n            # cfg.gpus will be ignored if distributed\n            len(cfg.gpu_ids),\n            dist=distributed,\n            seed=cfg.seed,\n            drop_last=True) for ds in dataset\n    ]\n\n    # put model on gpus\n    if distributed:\n        find_unused_parameters = cfg.get('find_unused_parameters', False)\n        # Sets the `find_unused_parameters` parameter in\n        # torch.nn.parallel.DistributedDataParallel\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False,\n            find_unused_parameters=find_unused_parameters)\n    else:\n        model = MMDataParallel(\n            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)\n\n    # build runner\n    optimizer = build_optimizer(model, cfg.optimizer)\n\n    if cfg.get('runner') is None:\n        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}\n        warnings.warn(\n            'config is now expected to have a `runner` section, '\n            'please set `runner` in your config.', UserWarning)\n\n    runner = build_runner(\n        cfg.runner,\n        default_args=dict(\n            model=model,\n            batch_processor=None,\n            optimizer=optimizer,\n            work_dir=cfg.work_dir,\n            logger=logger,\n            meta=meta))\n\n    # register hooks\n    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,\n                                   cfg.checkpoint_config, cfg.log_config,\n                                   cfg.get('momentum_config', None))\n\n    # an ugly walkaround to make the .log and .log.json filenames the same\n    runner.timestamp = timestamp\n\n    # register eval hooks\n    if validate:\n        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))\n        val_dataloader = build_dataloader(\n            val_dataset,\n            samples_per_gpu=1,\n            
workers_per_gpu=cfg.data.workers_per_gpu,\n            dist=distributed,\n            shuffle=False)\n        eval_cfg = cfg.get('evaluation', {})\n        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'\n        eval_hook = DistEvalHook if distributed else EvalHook\n        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))\n\n    if cfg.resume_from:\n        runner.resume(cfg.resume_from)\n    elif cfg.load_from:\n        runner.load_checkpoint(cfg.load_from)\n    runner.run(data_loaders, cfg.workflow)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/__init__.py",
    "content": "from .evaluation import *  # noqa: F401, F403\nfrom .seg import *  # noqa: F401, F403\nfrom .utils import *  # noqa: F401, F403\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/evaluation/__init__.py",
    "content": "from .class_names import get_classes, get_palette\nfrom .eval_hooks import DistEvalHook, EvalHook\nfrom .metrics import eval_metrics, mean_dice, mean_iou\n\n__all__ = [\n    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',\n    'get_classes', 'get_palette'\n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/evaluation/class_names.py",
    "content": "import mmcv\n\n\ndef cityscapes_classes():\n    \"\"\"Cityscapes class names for external use.\"\"\"\n    return [\n        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',\n        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',\n        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',\n        'bicycle'\n    ]\n\n\ndef ade_classes():\n    \"\"\"ADE20K class names for external use.\"\"\"\n    return [\n        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',\n        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',\n        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',\n        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',\n        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',\n        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',\n        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',\n        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',\n        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',\n        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',\n        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',\n        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',\n        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',\n        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',\n        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',\n        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',\n        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',\n        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',\n        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',\n        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',\n        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',\n        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',\n        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',\n        'clock', 'flag'\n    ]\n\n\ndef voc_classes():\n    \"\"\"Pascal VOC class names for external use.\"\"\"\n    return [\n        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',\n        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',\n        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',\n        'tvmonitor'\n    ]\n\n\ndef cityscapes_palette():\n    \"\"\"Cityscapes palette for external use.\"\"\"\n    return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],\n            [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],\n            [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],\n            [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],\n            [0, 0, 230], [119, 11, 32]]\n\n\ndef ade_palette():\n    \"\"\"ADE20K palette for external use.\"\"\"\n    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],\n            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],\n            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],\n            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],\n            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],\n            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 
255],\n            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],\n            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],\n            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],\n            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],\n            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],\n            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],\n            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],\n            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],\n            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],\n            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],\n            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],\n            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],\n            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],\n            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],\n            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],\n            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],\n            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],\n            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],\n            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],\n            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],\n            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],\n            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],\n            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],\n            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],\n            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],\n            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],\n            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],\n            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],\n            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],\n            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],\n            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],\n            [102, 255, 0], [92, 0, 255]]\n\n\ndef voc_palette():\n    \"\"\"Pascal VOC palette for external use.\"\"\"\n    return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],\n            [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],\n            [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],\n            [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],\n            [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]\n\n\ndataset_aliases = {\n    'cityscapes': ['cityscapes'],\n    'ade': ['ade', 'ade20k'],\n    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']\n}\n\n\ndef get_classes(dataset):\n    \"\"\"Get class names of a dataset.\"\"\"\n    alias2name = {}\n    for name, aliases in dataset_aliases.items():\n        for alias in aliases:\n            alias2name[alias] = name\n\n    if mmcv.is_str(dataset):\n        if dataset in alias2name:\n            labels = eval(alias2name[dataset] + '_classes()')\n        else:\n            raise ValueError(f'Unrecognized dataset: {dataset}')\n    else:\n        raise TypeError(f'dataset must a str, but got {type(dataset)}')\n    return labels\n\n\ndef get_palette(dataset):\n    \"\"\"Get class palette (RGB) of a dataset.\"\"\"\n    alias2name = {}\n    for name, aliases in 
dataset_aliases.items():\n        for alias in aliases:\n            alias2name[alias] = name\n\n    if mmcv.is_str(dataset):\n        if dataset in alias2name:\n            labels = eval(alias2name[dataset] + '_palette()')\n        else:\n            raise ValueError(f'Unrecognized dataset: {dataset}')\n    else:\n        raise TypeError(f'dataset must be a str, but got {type(dataset)}')\n    return labels\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/evaluation/eval_hooks.py",
    "content": "import os.path as osp\n\nfrom mmcv.runner import Hook\nfrom torch.utils.data import DataLoader\n\n\nclass EvalHook(Hook):\n    \"\"\"Evaluation hook.\n\n    Attributes:\n        dataloader (DataLoader): A PyTorch dataloader.\n        interval (int): Evaluation interval (by epochs). Default: 1.\n    \"\"\"\n\n    def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs):\n        if not isinstance(dataloader, DataLoader):\n            raise TypeError('dataloader must be a pytorch DataLoader, but got '\n                            f'{type(dataloader)}')\n        self.dataloader = dataloader\n        self.interval = interval\n        self.by_epoch = by_epoch\n        self.eval_kwargs = eval_kwargs\n\n    def after_train_iter(self, runner):\n        \"\"\"After train epoch hook.\"\"\"\n        if self.by_epoch or not self.every_n_iters(runner, self.interval):\n            return\n        from mmseg.apis import single_gpu_test\n        runner.log_buffer.clear()\n        results = single_gpu_test(runner.model, self.dataloader, show=False)\n        self.evaluate(runner, results)\n\n    def after_train_epoch(self, runner):\n        \"\"\"After train epoch hook.\"\"\"\n        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):\n            return\n        from mmseg.apis import single_gpu_test\n        runner.log_buffer.clear()\n        results = single_gpu_test(runner.model, self.dataloader, show=False)\n        self.evaluate(runner, results)\n\n    def evaluate(self, runner, results):\n        \"\"\"Call evaluate function of dataset.\"\"\"\n        eval_res = self.dataloader.dataset.evaluate(\n            results, logger=runner.logger, **self.eval_kwargs)\n        for name, val in eval_res.items():\n            runner.log_buffer.output[name] = val\n        runner.log_buffer.ready = True\n\n\nclass DistEvalHook(EvalHook):\n    \"\"\"Distributed evaluation hook.\n\n    Attributes:\n        dataloader (DataLoader): A PyTorch dataloader.\n        interval (int): Evaluation interval (by epochs). Default: 1.\n        tmpdir (str | None): Temporary directory to save the results of all\n            processes. 
Default: None.\n        gpu_collect (bool): Whether to use gpu or cpu to collect results.\n            Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 dataloader,\n                 interval=1,\n                 gpu_collect=False,\n                 by_epoch=False,\n                 **eval_kwargs):\n        if not isinstance(dataloader, DataLoader):\n            raise TypeError(\n                'dataloader must be a pytorch DataLoader, but got {}'.format(\n                    type(dataloader)))\n        self.dataloader = dataloader\n        self.interval = interval\n        self.gpu_collect = gpu_collect\n        self.by_epoch = by_epoch\n        self.eval_kwargs = eval_kwargs\n\n    def after_train_iter(self, runner):\n        \"\"\"After train epoch hook.\"\"\"\n        if self.by_epoch or not self.every_n_iters(runner, self.interval):\n            return\n        from mmseg.apis import multi_gpu_test\n        runner.log_buffer.clear()\n        results = multi_gpu_test(\n            runner.model,\n            self.dataloader,\n            tmpdir=osp.join(runner.work_dir, '.eval_hook'),\n            gpu_collect=self.gpu_collect)\n        if runner.rank == 0:\n            print('\\n')\n            self.evaluate(runner, results)\n\n    def after_train_epoch(self, runner):\n        \"\"\"After train epoch hook.\"\"\"\n        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):\n            return\n        from mmseg.apis import multi_gpu_test\n        runner.log_buffer.clear()\n        results = multi_gpu_test(\n            runner.model,\n            self.dataloader,\n            tmpdir=osp.join(runner.work_dir, '.eval_hook'),\n            gpu_collect=self.gpu_collect)\n        if runner.rank == 0:\n            print('\\n')\n            self.evaluate(runner, results)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/evaluation/metrics.py",
    "content": "import mmcv\nimport numpy as np\n\n\ndef intersect_and_union(pred_label,\n                        label,\n                        num_classes,\n                        ignore_index,\n                        label_map=dict(),\n                        reduce_zero_label=False):\n    \"\"\"Calculate intersection and Union.\n\n    Args:\n        pred_label (ndarray): Prediction segmentation map.\n        label (ndarray): Ground truth segmentation map.\n        num_classes (int): Number of categories.\n        ignore_index (int): Index that will be ignored in evaluation.\n        label_map (dict): Mapping old labels to new labels. The parameter will\n            work only when label is str. Default: dict().\n        reduce_zero_label (bool): Wether ignore zero label. The parameter will\n            work only when label is str. Default: False.\n\n     Returns:\n         ndarray: The intersection of prediction and ground truth histogram\n             on all classes.\n         ndarray: The union of prediction and ground truth histogram on all\n             classes.\n         ndarray: The prediction histogram on all classes.\n         ndarray: The ground truth histogram on all classes.\n    \"\"\"\n\n    if isinstance(pred_label, str):\n        pred_label = np.load(pred_label)\n\n    if isinstance(label, str):\n        label = mmcv.imread(label, flag='unchanged', backend='pillow')\n    # modify if custom classes\n    if label_map is not None:\n        for old_id, new_id in label_map.items():\n            label[label == old_id] = new_id\n    if reduce_zero_label:\n        # avoid using underflow conversion\n        label[label == 0] = 255\n        label = label - 1\n        label[label == 254] = 255\n\n    mask = (label != ignore_index)\n    pred_label = pred_label[mask]\n    label = label[mask]\n\n    intersect = pred_label[pred_label == label]\n    area_intersect, _ = np.histogram(\n        intersect, bins=np.arange(num_classes + 1))\n    area_pred_label, _ = np.histogram(\n        pred_label, bins=np.arange(num_classes + 1))\n    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))\n    area_union = area_pred_label + area_label - area_intersect\n\n    return area_intersect, area_union, area_pred_label, area_label\n\n\ndef total_intersect_and_union(results,\n                              gt_seg_maps,\n                              num_classes,\n                              ignore_index,\n                              label_map=dict(),\n                              reduce_zero_label=False):\n    \"\"\"Calculate Total Intersection and Union.\n\n    Args:\n        results (list[ndarray]): List of prediction segmentation maps.\n        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.\n        num_classes (int): Number of categories.\n        ignore_index (int): Index that will be ignored in evaluation.\n        label_map (dict): Mapping old labels to new labels. Default: dict().\n        reduce_zero_label (bool): Wether ignore zero label. 
Default: False.\n\n     Returns:\n         ndarray: The intersection of prediction and ground truth histogram\n             on all classes.\n         ndarray: The union of prediction and ground truth histogram on all\n             classes.\n         ndarray: The prediction histogram on all classes.\n         ndarray: The ground truth histogram on all classes.\n    \"\"\"\n\n    num_imgs = len(results)\n    assert len(gt_seg_maps) == num_imgs\n    total_area_intersect = np.zeros((num_classes, ), dtype=np.float)\n    total_area_union = np.zeros((num_classes, ), dtype=np.float)\n    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)\n    total_area_label = np.zeros((num_classes, ), dtype=np.float)\n    for i in range(num_imgs):\n        area_intersect, area_union, area_pred_label, area_label = \\\n            intersect_and_union(results[i], gt_seg_maps[i], num_classes,\n                                ignore_index, label_map, reduce_zero_label)\n        total_area_intersect += area_intersect\n        total_area_union += area_union\n        total_area_pred_label += area_pred_label\n        total_area_label += area_label\n    return total_area_intersect, total_area_union, \\\n        total_area_pred_label, total_area_label\n\n\ndef mean_iou(results,\n             gt_seg_maps,\n             num_classes,\n             ignore_index,\n             nan_to_num=None,\n             label_map=dict(),\n             reduce_zero_label=False):\n    \"\"\"Calculate Mean Intersection and Union (mIoU)\n\n    Args:\n        results (list[ndarray]): List of prediction segmentation maps.\n        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.\n        num_classes (int): Number of categories.\n        ignore_index (int): Index that will be ignored in evaluation.\n        nan_to_num (int, optional): If specified, NaN values will be replaced\n            by the numbers defined by the user. Default: None.\n        label_map (dict): Mapping old labels to new labels. Default: dict().\n        reduce_zero_label (bool): Wether ignore zero label. Default: False.\n\n     Returns:\n         float: Overall accuracy on all images.\n         ndarray: Per category accuracy, shape (num_classes, ).\n         ndarray: Per category IoU, shape (num_classes, ).\n    \"\"\"\n\n    all_acc, acc, iou = eval_metrics(\n        results=results,\n        gt_seg_maps=gt_seg_maps,\n        num_classes=num_classes,\n        ignore_index=ignore_index,\n        metrics=['mIoU'],\n        nan_to_num=nan_to_num,\n        label_map=label_map,\n        reduce_zero_label=reduce_zero_label)\n    return all_acc, acc, iou\n\n\ndef mean_dice(results,\n              gt_seg_maps,\n              num_classes,\n              ignore_index,\n              nan_to_num=None,\n              label_map=dict(),\n              reduce_zero_label=False):\n    \"\"\"Calculate Mean Dice (mDice)\n\n    Args:\n        results (list[ndarray]): List of prediction segmentation maps.\n        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.\n        num_classes (int): Number of categories.\n        ignore_index (int): Index that will be ignored in evaluation.\n        nan_to_num (int, optional): If specified, NaN values will be replaced\n            by the numbers defined by the user. Default: None.\n        label_map (dict): Mapping old labels to new labels. Default: dict().\n        reduce_zero_label (bool): Wether ignore zero label. 
Default: False.\n\n     Returns:\n         float: Overall accuracy on all images.\n         ndarray: Per category accuracy, shape (num_classes, ).\n         ndarray: Per category dice, shape (num_classes, ).\n    \"\"\"\n\n    all_acc, acc, dice = eval_metrics(\n        results=results,\n        gt_seg_maps=gt_seg_maps,\n        num_classes=num_classes,\n        ignore_index=ignore_index,\n        metrics=['mDice'],\n        nan_to_num=nan_to_num,\n        label_map=label_map,\n        reduce_zero_label=reduce_zero_label)\n    return all_acc, acc, dice\n\n\ndef eval_metrics(results,\n                 gt_seg_maps,\n                 num_classes,\n                 ignore_index,\n                 metrics=['mIoU'],\n                 nan_to_num=None,\n                 label_map=dict(),\n                 reduce_zero_label=False):\n    \"\"\"Calculate evaluation metrics\n    Args:\n        results (list[ndarray]): List of prediction segmentation maps.\n        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.\n        num_classes (int): Number of categories.\n        ignore_index (int): Index that will be ignored in evaluation.\n        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.\n        nan_to_num (int, optional): If specified, NaN values will be replaced\n            by the numbers defined by the user. Default: None.\n        label_map (dict): Mapping old labels to new labels. Default: dict().\n        reduce_zero_label (bool): Wether ignore zero label. Default: False.\n     Returns:\n         float: Overall accuracy on all images.\n         ndarray: Per category accuracy, shape (num_classes, ).\n         ndarray: Per category evalution metrics, shape (num_classes, ).\n    \"\"\"\n\n    if isinstance(metrics, str):\n        metrics = [metrics]\n    allowed_metrics = ['mIoU', 'mDice']\n    if not set(metrics).issubset(set(allowed_metrics)):\n        raise KeyError('metrics {} is not supported'.format(metrics))\n    total_area_intersect, total_area_union, total_area_pred_label, \\\n        total_area_label = total_intersect_and_union(results, gt_seg_maps,\n                                                     num_classes, ignore_index,\n                                                     label_map,\n                                                     reduce_zero_label)\n    all_acc = total_area_intersect.sum() / total_area_label.sum()\n    acc = total_area_intersect / total_area_label\n    ret_metrics = [all_acc, acc]\n    for metric in metrics:\n        if metric == 'mIoU':\n            iou = total_area_intersect / total_area_union\n            ret_metrics.append(iou)\n        elif metric == 'mDice':\n            dice = 2 * total_area_intersect / (\n                total_area_pred_label + total_area_label)\n            ret_metrics.append(dice)\n    if nan_to_num is not None:\n        ret_metrics = [\n            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics\n        ]\n    return ret_metrics\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/seg/__init__.py",
    "content": "from .builder import build_pixel_sampler\nfrom .sampler import BasePixelSampler, OHEMPixelSampler\n\n__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/seg/builder.py",
    "content": "from mmcv.utils import Registry, build_from_cfg\n\nPIXEL_SAMPLERS = Registry('pixel sampler')\n\n\ndef build_pixel_sampler(cfg, **default_args):\n    \"\"\"Build pixel sampler for segmentation map.\"\"\"\n    return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/__init__.py",
    "content": "from .base_pixel_sampler import BasePixelSampler\nfrom .ohem_pixel_sampler import OHEMPixelSampler\n\n__all__ = ['BasePixelSampler', 'OHEMPixelSampler']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/base_pixel_sampler.py",
    "content": "from abc import ABCMeta, abstractmethod\n\n\nclass BasePixelSampler(metaclass=ABCMeta):\n    \"\"\"Base class of pixel sampler.\"\"\"\n\n    def __init__(self, **kwargs):\n        pass\n\n    @abstractmethod\n    def sample(self, seg_logit, seg_label):\n        \"\"\"Placeholder for sample function.\"\"\"\n        pass\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/ohem_pixel_sampler.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom ..builder import PIXEL_SAMPLERS\nfrom .base_pixel_sampler import BasePixelSampler\n\n\n@PIXEL_SAMPLERS.register_module()\nclass OHEMPixelSampler(BasePixelSampler):\n    \"\"\"Online Hard Example Mining Sampler for segmentation.\n\n    Args:\n        context (nn.Module): The context of sampler, subclass of\n            :obj:`BaseDecodeHead`.\n        thresh (float, optional): The threshold for hard example selection.\n            Below which, are prediction with low confidence. If not\n            specified, the hard examples will be pixels of top ``min_kept``\n            loss. Default: None.\n        min_kept (int, optional): The minimum number of predictions to keep.\n            Default: 100000.\n    \"\"\"\n\n    def __init__(self, context, thresh=None, min_kept=100000):\n        super(OHEMPixelSampler, self).__init__()\n        self.context = context\n        assert min_kept > 1\n        self.thresh = thresh\n        self.min_kept = min_kept\n\n    def sample(self, seg_logit, seg_label):\n        \"\"\"Sample pixels that have high loss or with low prediction confidence.\n\n        Args:\n            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)\n            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)\n\n        Returns:\n            torch.Tensor: segmentation weight, shape (N, H, W)\n        \"\"\"\n        with torch.no_grad():\n            assert seg_logit.shape[2:] == seg_label.shape[2:]\n            assert seg_label.shape[1] == 1\n            seg_label = seg_label.squeeze(1).long()\n            batch_kept = self.min_kept * seg_label.size(0)\n            valid_mask = seg_label != self.context.ignore_index\n            seg_weight = seg_logit.new_zeros(size=seg_label.size())\n            valid_seg_weight = seg_weight[valid_mask]\n            if self.thresh is not None:\n                seg_prob = F.softmax(seg_logit, dim=1)\n\n                tmp_seg_label = seg_label.clone().unsqueeze(1)\n                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0\n                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)\n                sort_prob, sort_indices = seg_prob[valid_mask].sort()\n\n                if sort_prob.numel() > 0:\n                    min_threshold = sort_prob[min(batch_kept,\n                                                  sort_prob.numel() - 1)]\n                else:\n                    min_threshold = 0.0\n                threshold = max(min_threshold, self.thresh)\n                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.\n            else:\n                losses = self.context.loss_decode(\n                    seg_logit,\n                    seg_label,\n                    weight=None,\n                    ignore_index=self.context.ignore_index,\n                    reduction_override='none')\n                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa\n                _, sort_indices = losses[valid_mask].sort(descending=True)\n                valid_seg_weight[sort_indices[:batch_kept]] = 1.\n\n            seg_weight[valid_mask] = valid_seg_weight\n\n            return seg_weight\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/utils/__init__.py",
    "content": "from .misc import add_prefix\n\n__all__ = ['add_prefix']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/core/utils/misc.py",
    "content": "def add_prefix(inputs, prefix):\n    \"\"\"Add prefix for dict.\n\n    Args:\n        inputs (dict): The input dict with str keys.\n        prefix (str): The prefix to add.\n\n    Returns:\n\n        dict: The dict with keys updated with ``prefix``.\n    \"\"\"\n\n    outputs = dict()\n    for name, value in inputs.items():\n        outputs[f'{prefix}.{name}'] = value\n\n    return outputs\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/__init__.py",
    "content": "from .ade import ADE20KDataset\nfrom .builder import DATASETS, PIPELINES, build_dataloader, build_dataset\nfrom .chase_db1 import ChaseDB1Dataset\nfrom .cityscapes import CityscapesDataset\nfrom .custom import CustomDataset\nfrom .dataset_wrappers import ConcatDataset, RepeatDataset\nfrom .drive import DRIVEDataset\nfrom .hrf import HRFDataset\nfrom .pascal_context import PascalContextDataset\nfrom .stare import STAREDataset\nfrom .voc import PascalVOCDataset\nfrom .coco_stuff import COCOStuffDataset\n\n__all__ = [\n    'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',\n    'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',\n    'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',\n    'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'COCOStuffDataset',\n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/ade.py",
    "content": "from .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass ADE20KDataset(CustomDataset):\n    \"\"\"ADE20K dataset.\n\n    In segmentation map annotation for ADE20K, 0 stands for background, which\n    is not included in 150 categories. ``reduce_zero_label`` is fixed to True.\n    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to\n    '.png'.\n    \"\"\"\n    CLASSES = (\n        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',\n        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',\n        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',\n        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',\n        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',\n        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',\n        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',\n        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',\n        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',\n        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',\n        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',\n        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',\n        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',\n        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',\n        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',\n        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',\n        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',\n        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',\n        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',\n        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',\n        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',\n        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',\n        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',\n        'clock', 'flag')\n\n    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],\n               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],\n               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],\n               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],\n               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],\n               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],\n               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],\n               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],\n               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],\n               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],\n               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],\n               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],\n               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],\n               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],\n               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],\n               [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],\n               [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],\n               [255, 
102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],\n               [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],\n               [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],\n               [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],\n               [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],\n               [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],\n               [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],\n               [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],\n               [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],\n               [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],\n               [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],\n               [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],\n               [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],\n               [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],\n               [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],\n               [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],\n               [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],\n               [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],\n               [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],\n               [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],\n               [102, 255, 0], [92, 0, 255]]\n\n    def __init__(self, **kwargs):\n        super(ADE20KDataset, self).__init__(\n            img_suffix='.jpg',\n            seg_map_suffix='.png',\n            reduce_zero_label=True,\n            **kwargs)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/builder.py",
    "content": "import copy\nimport platform\nimport random\nfrom functools import partial\n\nimport numpy as np\nfrom mmcv.parallel import collate\nfrom mmcv.runner import get_dist_info\nfrom mmcv.utils import Registry, build_from_cfg\nfrom mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader\nfrom torch.utils.data import DistributedSampler\n\nif platform.system() != 'Windows':\n    # https://github.com/pytorch/pytorch/issues/973\n    import resource\n    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)\n    hard_limit = rlimit[1]\n    soft_limit = min(4096, hard_limit)\n    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))\n\nDATASETS = Registry('dataset')\nPIPELINES = Registry('pipeline')\n\n\ndef _concat_dataset(cfg, default_args=None):\n    \"\"\"Build :obj:`ConcatDataset by.\"\"\"\n    from .dataset_wrappers import ConcatDataset\n    img_dir = cfg['img_dir']\n    ann_dir = cfg.get('ann_dir', None)\n    split = cfg.get('split', None)\n    num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1\n    if ann_dir is not None:\n        num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1\n    else:\n        num_ann_dir = 0\n    if split is not None:\n        num_split = len(split) if isinstance(split, (list, tuple)) else 1\n    else:\n        num_split = 0\n    if num_img_dir > 1:\n        assert num_img_dir == num_ann_dir or num_ann_dir == 0\n        assert num_img_dir == num_split or num_split == 0\n    else:\n        assert num_split == num_ann_dir or num_ann_dir <= 1\n    num_dset = max(num_split, num_img_dir)\n\n    datasets = []\n    for i in range(num_dset):\n        data_cfg = copy.deepcopy(cfg)\n        if isinstance(img_dir, (list, tuple)):\n            data_cfg['img_dir'] = img_dir[i]\n        if isinstance(ann_dir, (list, tuple)):\n            data_cfg['ann_dir'] = ann_dir[i]\n        if isinstance(split, (list, tuple)):\n            data_cfg['split'] = split[i]\n        datasets.append(build_dataset(data_cfg, default_args))\n\n    return ConcatDataset(datasets)\n\n\ndef build_dataset(cfg, default_args=None):\n    \"\"\"Build datasets.\"\"\"\n    from .dataset_wrappers import ConcatDataset, RepeatDataset\n    if isinstance(cfg, (list, tuple)):\n        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])\n    elif cfg['type'] == 'RepeatDataset':\n        dataset = RepeatDataset(\n            build_dataset(cfg['dataset'], default_args), cfg['times'])\n    elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(\n            cfg.get('split', None), (list, tuple)):\n        dataset = _concat_dataset(cfg, default_args)\n    else:\n        dataset = build_from_cfg(cfg, DATASETS, default_args)\n\n    return dataset\n\n\ndef build_dataloader(dataset,\n                     samples_per_gpu,\n                     workers_per_gpu,\n                     num_gpus=1,\n                     dist=True,\n                     shuffle=True,\n                     seed=None,\n                     drop_last=False,\n                     pin_memory=True,\n                     dataloader_type='PoolDataLoader',\n                     **kwargs):\n    \"\"\"Build PyTorch DataLoader.\n\n    In distributed training, each GPU/process has a dataloader.\n    In non-distributed training, there is only one dataloader for all GPUs.\n\n    Args:\n        dataset (Dataset): A PyTorch dataset.\n        samples_per_gpu (int): Number of training samples on each GPU, i.e.,\n            batch size of each GPU.\n        
workers_per_gpu (int): How many subprocesses to use for data loading\n            for each GPU.\n        num_gpus (int): Number of GPUs. Only used in non-distributed training.\n        dist (bool): Distributed training/test or not. Default: True.\n        shuffle (bool): Whether to shuffle the data at every epoch.\n            Default: True.\n        seed (int | None): Seed to be used. Default: None.\n        drop_last (bool): Whether to drop the last incomplete batch in epoch.\n            Default: False\n        pin_memory (bool): Whether to use pin_memory in DataLoader.\n            Default: True\n        dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'\n        kwargs: any keyword argument to be used to initialize DataLoader\n\n    Returns:\n        DataLoader: A PyTorch dataloader.\n    \"\"\"\n    rank, world_size = get_dist_info()\n    if dist:\n        sampler = DistributedSampler(\n            dataset, world_size, rank, shuffle=shuffle)\n        shuffle = False\n        batch_size = samples_per_gpu\n        num_workers = workers_per_gpu\n    else:\n        sampler = None\n        batch_size = num_gpus * samples_per_gpu\n        num_workers = num_gpus * workers_per_gpu\n\n    init_fn = partial(\n        worker_init_fn, num_workers=num_workers, rank=rank,\n        seed=seed) if seed is not None else None\n\n    assert dataloader_type in (\n        'DataLoader',\n        'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'\n\n    if dataloader_type == 'PoolDataLoader':\n        dataloader = PoolDataLoader\n    elif dataloader_type == 'DataLoader':\n        dataloader = DataLoader\n\n    data_loader = dataloader(\n        dataset,\n        batch_size=batch_size,\n        sampler=sampler,\n        num_workers=num_workers,\n        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),\n        pin_memory=pin_memory,\n        shuffle=shuffle,\n        worker_init_fn=init_fn,\n        drop_last=drop_last,\n        **kwargs)\n\n    return data_loader\n\n\ndef worker_init_fn(worker_id, num_workers, rank, seed):\n    \"\"\"Worker init func for dataloader.\n\n    The seed of each worker equals to num_worker * rank + worker_id + user_seed\n\n    Args:\n        worker_id (int): Worker id.\n        num_workers (int): Number of workers.\n        rank (int): The rank of current process.\n        seed (int): The random seed to use.\n    \"\"\"\n\n    worker_seed = num_workers * rank + worker_id + seed\n    np.random.seed(worker_seed)\n    random.seed(worker_seed)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/chase_db1.py",
    "content": "import os.path as osp\n\nfrom .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass ChaseDB1Dataset(CustomDataset):\n    \"\"\"Chase_db1 dataset.\n\n    In segmentation map annotation for Chase_db1, 0 stands for background,\n    which is included in 2 categories. ``reduce_zero_label`` is fixed to False.\n    The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to\n    '_1stHO.png'.\n    \"\"\"\n\n    CLASSES = ('background', 'vessel')\n\n    PALETTE = [[120, 120, 120], [6, 230, 230]]\n\n    def __init__(self, **kwargs):\n        super(ChaseDB1Dataset, self).__init__(\n            img_suffix='.png',\n            seg_map_suffix='_1stHO.png',\n            reduce_zero_label=False,\n            **kwargs)\n        assert osp.exists(self.img_dir)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/cityscapes.py",
    "content": "import os.path as osp\nimport tempfile\n\nimport mmcv\nimport numpy as np\nfrom mmcv.utils import print_log\nfrom PIL import Image\n\nfrom .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass CityscapesDataset(CustomDataset):\n    \"\"\"Cityscapes dataset.\n\n    The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is\n    fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.\n    \"\"\"\n\n    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',\n               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',\n               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',\n               'bicycle')\n\n    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],\n               [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],\n               [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],\n               [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],\n               [0, 80, 100], [0, 0, 230], [119, 11, 32]]\n\n    def __init__(self, **kwargs):\n        super(CityscapesDataset, self).__init__(\n            img_suffix='_leftImg8bit.png',\n            seg_map_suffix='_gtFine_labelTrainIds.png',\n            **kwargs)\n\n    @staticmethod\n    def _convert_to_label_id(result):\n        \"\"\"Convert trainId to id for cityscapes.\"\"\"\n        if isinstance(result, str):\n            result = np.load(result)\n        import cityscapesscripts.helpers.labels as CSLabels\n        result_copy = result.copy()\n        for trainId, label in CSLabels.trainId2label.items():\n            result_copy[result == trainId] = label.id\n\n        return result_copy\n\n    def results2img(self, results, imgfile_prefix, to_label_id):\n        \"\"\"Write the segmentation results to images.\n\n        Args:\n            results (list[list | tuple | ndarray]): Testing results of the\n                dataset.\n            imgfile_prefix (str): The filename prefix of the png files.\n                If the prefix is \"somepath/xxx\",\n                the png files will be named \"somepath/xxx.png\".\n            to_label_id (bool): whether convert output to label_id for\n                submission\n\n        Returns:\n            list[str: str]: result txt files which contains corresponding\n            semantic segmentation images.\n        \"\"\"\n        mmcv.mkdir_or_exist(imgfile_prefix)\n        result_files = []\n        prog_bar = mmcv.ProgressBar(len(self))\n        for idx in range(len(self)):\n            result = results[idx]\n            if to_label_id:\n                result = self._convert_to_label_id(result)\n            filename = self.img_infos[idx]['filename']\n            basename = osp.splitext(osp.basename(filename))[0]\n\n            png_filename = osp.join(imgfile_prefix, f'{basename}.png')\n\n            output = Image.fromarray(result.astype(np.uint8)).convert('P')\n            import cityscapesscripts.helpers.labels as CSLabels\n            palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)\n            for label_id, label in CSLabels.id2label.items():\n                palette[label_id] = label.color\n\n            output.putpalette(palette)\n            output.save(png_filename)\n            result_files.append(png_filename)\n            prog_bar.update()\n\n        return result_files\n\n    def format_results(self, results, imgfile_prefix=None, to_label_id=True):\n        \"\"\"Format 
the results into dir (standard format for Cityscapes\n        evaluation).\n\n        Args:\n            results (list): Testing results of the dataset.\n            imgfile_prefix (str | None): The prefix of images files. It\n                includes the file path and the prefix of filename, e.g.,\n                \"a/b/prefix\". If not specified, a temp file will be created.\n                Default: None.\n            to_label_id (bool): whether convert output to label_id for\n                submission. Default: False\n\n        Returns:\n            tuple: (result_files, tmp_dir), result_files is a list containing\n                the image paths, tmp_dir is the temporal directory created\n                for saving json/png files when img_prefix is not specified.\n        \"\"\"\n\n        assert isinstance(results, list), 'results must be a list'\n        assert len(results) == len(self), (\n            'The length of results is not equal to the dataset len: '\n            f'{len(results)} != {len(self)}')\n\n        if imgfile_prefix is None:\n            tmp_dir = tempfile.TemporaryDirectory()\n            imgfile_prefix = tmp_dir.name\n        else:\n            tmp_dir = None\n        result_files = self.results2img(results, imgfile_prefix, to_label_id)\n\n        return result_files, tmp_dir\n\n    def evaluate(self,\n                 results,\n                 metric='mIoU',\n                 logger=None,\n                 imgfile_prefix=None,\n                 efficient_test=False):\n        \"\"\"Evaluation in Cityscapes/default protocol.\n\n        Args:\n            results (list): Testing results of the dataset.\n            metric (str | list[str]): Metrics to be evaluated.\n            logger (logging.Logger | None | str): Logger used for printing\n                related information during evaluation. Default: None.\n            imgfile_prefix (str | None): The prefix of output image file,\n                for cityscapes evaluation only. It includes the file path and\n                the prefix of filename, e.g., \"a/b/prefix\".\n                If results are evaluated with cityscapes protocol, it would be\n                the prefix of output png files. The output files would be\n                png images under folder \"a/b/prefix/xxx.png\", where \"xxx\" is\n                the image name of cityscapes. If not specified, a temp file\n                will be created for evaluation.\n                Default: None.\n\n        Returns:\n            dict[str, float]: Cityscapes/default metrics.\n        \"\"\"\n\n        eval_results = dict()\n        metrics = metric.copy() if isinstance(metric, list) else [metric]\n        if 'cityscapes' in metrics:\n            eval_results.update(\n                self._evaluate_cityscapes(results, logger, imgfile_prefix))\n            metrics.remove('cityscapes')\n        if len(metrics) > 0:\n            eval_results.update(\n                super(CityscapesDataset,\n                      self).evaluate(results, metrics, logger, efficient_test))\n\n        return eval_results\n\n    def _evaluate_cityscapes(self, results, logger, imgfile_prefix):\n        \"\"\"Evaluation in Cityscapes protocol.\n\n        Args:\n            results (list): Testing results of the dataset.\n            logger (logging.Logger | str | None): Logger used for printing\n                related information during evaluation. 
Default: None.\n            imgfile_prefix (str | None): The prefix of output image file\n\n        Returns:\n            dict[str: float]: Cityscapes evaluation results.\n        \"\"\"\n        try:\n            import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa\n        except ImportError:\n            raise ImportError('Please run \"pip install cityscapesscripts\" to '\n                              'install cityscapesscripts first.')\n        msg = 'Evaluating in Cityscapes style'\n        if logger is None:\n            msg = '\\n' + msg\n        print_log(msg, logger=logger)\n\n        result_files, tmp_dir = self.format_results(results, imgfile_prefix)\n\n        if tmp_dir is None:\n            result_dir = imgfile_prefix\n        else:\n            result_dir = tmp_dir.name\n\n        eval_results = dict()\n        print_log(f'Evaluating results under {result_dir} ...', logger=logger)\n\n        CSEval.args.evalInstLevelScore = True\n        CSEval.args.predictionPath = osp.abspath(result_dir)\n        CSEval.args.evalPixelAccuracy = True\n        CSEval.args.JSONOutput = False\n\n        seg_map_list = []\n        pred_list = []\n\n        # when evaluating with official cityscapesscripts,\n        # **_gtFine_labelIds.png is used\n        for seg_map in mmcv.scandir(\n                self.ann_dir, 'gtFine_labelIds.png', recursive=True):\n            seg_map_list.append(osp.join(self.ann_dir, seg_map))\n            pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))\n\n        eval_results.update(\n            CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))\n\n        if tmp_dir is not None:\n            tmp_dir.cleanup()\n\n        return eval_results\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/coco_stuff.py",
    "content": "from .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass COCOStuffDataset(CustomDataset):\n    \"\"\"COCO-Stuff dataset.\n    In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k version\n    are from 1 to 171, where 0 is the ignore index, and Train-ID of COCO Stuff\n    164k is from 0 to 170, where 255 is the ignore index. So, they are all 171\n    semantic categories. ``reduce_zero_label`` is set to True and False for the\n    10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg',\n    and ``seg_map_suffix`` is fixed to '.png'.\n    \"\"\"\n    CLASSES = (\n        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',\n        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',\n        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',\n        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',\n        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',\n        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',\n        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',\n        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',\n        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',\n        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',\n        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',\n        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',\n        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',\n        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',\n        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',\n        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',\n        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',\n        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood',\n        'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass',\n        'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat',\n        'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',\n        'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform',\n        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',\n        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',\n        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',\n        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',\n        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',\n        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',\n        'window-blind', 'window-other', 'wood')\n\n    PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],\n               [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],\n               [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],\n               [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],\n               [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],\n               [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],\n               [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],\n               [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],\n               [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],\n               [64, 0, 32], [0, 224, 128], [128, 0, 0], 
[192, 0, 160],\n               [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],\n               [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],\n               [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],\n               [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],\n               [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],\n               [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],\n               [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],\n               [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],\n               [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],\n               [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],\n               [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],\n               [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],\n               [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],\n               [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],\n               [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],\n               [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],\n               [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],\n               [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],\n               [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],\n               [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],\n               [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],\n               [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],\n               [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],\n               [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],\n               [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],\n               [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],\n               [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],\n               [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],\n               [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],\n               [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],\n               [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],\n               [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],\n               [64, 192, 96], [64, 160, 64], [64, 64, 0]]\n\n    def __init__(self, **kwargs):\n        super(COCOStuffDataset, self).__init__(\n            img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs)"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/custom.py",
    "content": "import os\nimport os.path as osp\nfrom functools import reduce\n\nimport mmcv\nimport numpy as np\nfrom mmcv.utils import print_log\nfrom terminaltables import AsciiTable\nfrom torch.utils.data import Dataset\n\nfrom mmseg.core import eval_metrics\nfrom mmseg.utils import get_root_logger\nfrom .builder import DATASETS\nfrom .pipelines import Compose\n\n\n@DATASETS.register_module()\nclass CustomDataset(Dataset):\n    \"\"\"Custom dataset for semantic segmentation. An example of file structure\n    is as followed.\n\n    .. code-block:: none\n\n        ├── data\n        │   ├── my_dataset\n        │   │   ├── img_dir\n        │   │   │   ├── train\n        │   │   │   │   ├── xxx{img_suffix}\n        │   │   │   │   ├── yyy{img_suffix}\n        │   │   │   │   ├── zzz{img_suffix}\n        │   │   │   ├── val\n        │   │   ├── ann_dir\n        │   │   │   ├── train\n        │   │   │   │   ├── xxx{seg_map_suffix}\n        │   │   │   │   ├── yyy{seg_map_suffix}\n        │   │   │   │   ├── zzz{seg_map_suffix}\n        │   │   │   ├── val\n\n    The img/gt_semantic_seg pair of CustomDataset should be of the same\n    except suffix. A valid img/gt_semantic_seg filename pair should be like\n    ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included\n    in the suffix). If split is given, then ``xxx`` is specified in txt file.\n    Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.\n    Please refer to ``docs/tutorials/new_dataset.md`` for more details.\n\n\n    Args:\n        pipeline (list[dict]): Processing pipeline\n        img_dir (str): Path to image directory\n        img_suffix (str): Suffix of images. Default: '.jpg'\n        ann_dir (str, optional): Path to annotation directory. Default: None\n        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'\n        split (str, optional): Split txt file. If split is specified, only\n            file with suffix in the splits will be loaded. Otherwise, all\n            images in img_dir/ann_dir will be loaded. Default: None\n        data_root (str, optional): Data root for img_dir/ann_dir. Default:\n            None.\n        test_mode (bool): If test_mode=True, gt wouldn't be loaded.\n        ignore_index (int): The label index to be ignored. Default: 255\n        reduce_zero_label (bool): Whether to mark label zero as ignored.\n            Default: False\n        classes (str | Sequence[str], optional): Specify classes to load.\n            If is None, ``cls.CLASSES`` will be used. Default: None.\n        palette (Sequence[Sequence[int]]] | np.ndarray | None):\n            The palette of segmentation map. 
If None is given, and\n            self.PALETTE is None, random palette will be generated.\n            Default: None\n    \"\"\"\n\n    CLASSES = None\n\n    PALETTE = None\n\n    def __init__(self,\n                 pipeline,\n                 img_dir,\n                 img_suffix='.jpg',\n                 ann_dir=None,\n                 seg_map_suffix='.png',\n                 split=None,\n                 data_root=None,\n                 test_mode=False,\n                 ignore_index=255,\n                 reduce_zero_label=False,\n                 classes=None,\n                 palette=None):\n        self.pipeline = Compose(pipeline)\n        self.img_dir = img_dir\n        self.img_suffix = img_suffix\n        self.ann_dir = ann_dir\n        self.seg_map_suffix = seg_map_suffix\n        self.split = split\n        self.data_root = data_root\n        self.test_mode = test_mode\n        self.ignore_index = ignore_index\n        self.reduce_zero_label = reduce_zero_label\n        self.label_map = None\n        self.CLASSES, self.PALETTE = self.get_classes_and_palette(\n            classes, palette)\n\n        # join paths if data_root is specified\n        if self.data_root is not None:\n            if not osp.isabs(self.img_dir):\n                self.img_dir = osp.join(self.data_root, self.img_dir)\n            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):\n                self.ann_dir = osp.join(self.data_root, self.ann_dir)\n            if not (self.split is None or osp.isabs(self.split)):\n                self.split = osp.join(self.data_root, self.split)\n\n        # load annotations\n        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,\n                                               self.ann_dir,\n                                               self.seg_map_suffix, self.split)\n\n    def __len__(self):\n        \"\"\"Total number of samples of data.\"\"\"\n        return len(self.img_infos)\n\n    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,\n                         split):\n        \"\"\"Load annotation from directory.\n\n        Args:\n            img_dir (str): Path to image directory\n            img_suffix (str): Suffix of images.\n            ann_dir (str|None): Path to annotation directory.\n            seg_map_suffix (str|None): Suffix of segmentation maps.\n            split (str|None): Split txt file. If split is specified, only file\n                with suffix in the splits will be loaded. Otherwise, all images\n                in img_dir/ann_dir will be loaded. 
Default: None\n\n        Returns:\n            list[dict]: All image info of dataset.\n        \"\"\"\n\n        img_infos = []\n        if split is not None:\n            with open(split) as f:\n                for line in f:\n                    img_name = line.strip()\n                    img_info = dict(filename=img_name + img_suffix)\n                    if ann_dir is not None:\n                        seg_map = img_name + seg_map_suffix\n                        img_info['ann'] = dict(seg_map=seg_map)\n                    img_infos.append(img_info)\n        else:\n            for img in mmcv.scandir(img_dir, img_suffix, recursive=True):\n                img_info = dict(filename=img)\n                if ann_dir is not None:\n                    seg_map = img.replace(img_suffix, seg_map_suffix)\n                    img_info['ann'] = dict(seg_map=seg_map)\n                img_infos.append(img_info)\n\n        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())\n        return img_infos\n\n    def get_ann_info(self, idx):\n        \"\"\"Get annotation by index.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Annotation info of specified index.\n        \"\"\"\n\n        return self.img_infos[idx]['ann']\n\n    def pre_pipeline(self, results):\n        \"\"\"Prepare results dict for pipeline.\"\"\"\n        results['seg_fields'] = []\n        results['img_prefix'] = self.img_dir\n        results['seg_prefix'] = self.ann_dir\n        if self.custom_classes:\n            results['label_map'] = self.label_map\n\n    def __getitem__(self, idx):\n        \"\"\"Get training/test data after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training/test data (with annotation if `test_mode` is set\n                False).\n        \"\"\"\n\n        if self.test_mode:\n            return self.prepare_test_img(idx)\n        else:\n            return self.prepare_train_img(idx)\n\n    def prepare_train_img(self, idx):\n        \"\"\"Get training data and annotations after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Training data and annotation after pipeline with new keys\n                introduced by pipeline.\n        \"\"\"\n\n        img_info = self.img_infos[idx]\n        ann_info = self.get_ann_info(idx)\n        results = dict(img_info=img_info, ann_info=ann_info)\n        self.pre_pipeline(results)\n        return self.pipeline(results)\n\n    def prepare_test_img(self, idx):\n        \"\"\"Get testing data after pipeline.\n\n        Args:\n            idx (int): Index of data.\n\n        Returns:\n            dict: Testing data after pipeline with new keys introduced by\n                pipeline.\n        \"\"\"\n\n        img_info = self.img_infos[idx]\n        results = dict(img_info=img_info)\n        self.pre_pipeline(results)\n        return self.pipeline(results)\n\n    def format_results(self, results, **kwargs):\n        \"\"\"Placeholder to format result to dataset specific output.\"\"\"\n        pass\n\n    def get_gt_seg_maps(self, efficient_test=False):\n        \"\"\"Get ground truth segmentation maps for evaluation.\"\"\"\n        gt_seg_maps = []\n        for img_info in self.img_infos:\n            seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])\n            if efficient_test:\n                gt_seg_map = seg_map\n            else:\n                gt_seg_map = mmcv.imread(\n       
             seg_map, flag='unchanged', backend='pillow')\n            gt_seg_maps.append(gt_seg_map)\n        return gt_seg_maps\n\n    def get_classes_and_palette(self, classes=None, palette=None):\n        \"\"\"Get class names of current dataset.\n\n        Args:\n            classes (Sequence[str] | str | None): If classes is None, use\n                default CLASSES defined by builtin dataset. If classes is a\n                string, take it as a file name. The file contains the name of\n                classes where each line contains one class name. If classes is\n                a tuple or list, override the CLASSES defined by the dataset.\n            palette (Sequence[Sequence[int]]] | np.ndarray | None):\n                The palette of segmentation map. If None is given, random\n                palette will be generated. Default: None\n        \"\"\"\n        if classes is None:\n            self.custom_classes = False\n            return self.CLASSES, self.PALETTE\n\n        self.custom_classes = True\n        if isinstance(classes, str):\n            # take it as a file path\n            class_names = mmcv.list_from_file(classes)\n        elif isinstance(classes, (tuple, list)):\n            class_names = classes\n        else:\n            raise ValueError(f'Unsupported type {type(classes)} of classes.')\n\n        if self.CLASSES:\n            if not set(classes).issubset(self.CLASSES):\n                raise ValueError('classes is not a subset of CLASSES.')\n\n            # dictionary, its keys are the old label ids and its values\n            # are the new label ids.\n            # used for changing pixel labels in load_annotations.\n            self.label_map = {}\n            for i, c in enumerate(self.CLASSES):\n                if c not in class_names:\n                    self.label_map[i] = -1\n                else:\n                    self.label_map[i] = classes.index(c)\n\n        palette = self.get_palette_for_custom_classes(class_names, palette)\n\n        return class_names, palette\n\n    def get_palette_for_custom_classes(self, class_names, palette=None):\n\n        if self.label_map is not None:\n            # return subset of palette\n            palette = []\n            for old_id, new_id in sorted(\n                    self.label_map.items(), key=lambda x: x[1]):\n                if new_id != -1:\n                    palette.append(self.PALETTE[old_id])\n            palette = type(self.PALETTE)(palette)\n\n        elif palette is None:\n            if self.PALETTE is None:\n                palette = np.random.randint(0, 255, size=(len(class_names), 3))\n            else:\n                palette = self.PALETTE\n\n        return palette\n\n    def evaluate(self,\n                 results,\n                 metric='mIoU',\n                 logger=None,\n                 efficient_test=False,\n                 **kwargs):\n        \"\"\"Evaluate the dataset.\n\n        Args:\n            results (list): Testing results of the dataset.\n            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and\n                'mDice' are supported.\n            logger (logging.Logger | None | str): Logger used for printing\n                related information during evaluation. 
Default: None.\n\n        Returns:\n            dict[str, float]: Default metrics.\n        \"\"\"\n\n        if isinstance(metric, str):\n            metric = [metric]\n        allowed_metrics = ['mIoU', 'mDice']\n        if not set(metric).issubset(set(allowed_metrics)):\n            raise KeyError('metric {} is not supported'.format(metric))\n        eval_results = {}\n        gt_seg_maps = self.get_gt_seg_maps(efficient_test)\n        if self.CLASSES is None:\n            num_classes = len(\n                reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))\n        else:\n            num_classes = len(self.CLASSES)\n        ret_metrics = eval_metrics(\n            results,\n            gt_seg_maps,\n            num_classes,\n            self.ignore_index,\n            metric,\n            label_map=self.label_map,\n            reduce_zero_label=self.reduce_zero_label)\n        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]\n        if self.CLASSES is None:\n            class_names = tuple(range(num_classes))\n        else:\n            class_names = self.CLASSES\n        ret_metrics_round = [\n            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics\n        ]\n        for i in range(num_classes):\n            class_table_data.append([class_names[i]] +\n                                    [m[i] for m in ret_metrics_round[2:]] +\n                                    [ret_metrics_round[1][i]])\n        summary_table_data = [['Scope'] +\n                              ['m' + head\n                               for head in class_table_data[0][1:]] + ['aAcc']]\n        ret_metrics_mean = [\n            np.round(np.nanmean(ret_metric) * 100, 2)\n            for ret_metric in ret_metrics\n        ]\n        summary_table_data.append(['global'] + ret_metrics_mean[2:] +\n                                  [ret_metrics_mean[1]] +\n                                  [ret_metrics_mean[0]])\n        print_log('per class results:', logger)\n        table = AsciiTable(class_table_data)\n        print_log('\\n' + table.table, logger=logger)\n        print_log('Summary:', logger)\n        table = AsciiTable(summary_table_data)\n        print_log('\\n' + table.table, logger=logger)\n\n        for i in range(1, len(summary_table_data[0])):\n            eval_results[summary_table_data[0]\n                         [i]] = summary_table_data[1][i] / 100.0\n        if mmcv.is_list_of(results, str):\n            for file_name in results:\n                os.remove(file_name)\n        return eval_results\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/dataset_wrappers.py",
    "content": "from torch.utils.data.dataset import ConcatDataset as _ConcatDataset\n\nfrom .builder import DATASETS\n\n\n@DATASETS.register_module()\nclass ConcatDataset(_ConcatDataset):\n    \"\"\"A wrapper of concatenated dataset.\n\n    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but\n    concat the group flag for image aspect ratio.\n\n    Args:\n        datasets (list[:obj:`Dataset`]): A list of datasets.\n    \"\"\"\n\n    def __init__(self, datasets):\n        super(ConcatDataset, self).__init__(datasets)\n        self.CLASSES = datasets[0].CLASSES\n        self.PALETTE = datasets[0].PALETTE\n\n\n@DATASETS.register_module()\nclass RepeatDataset(object):\n    \"\"\"A wrapper of repeated dataset.\n\n    The length of repeated dataset will be `times` larger than the original\n    dataset. This is useful when the data loading time is long but the dataset\n    is small. Using RepeatDataset can reduce the data loading time between\n    epochs.\n\n    Args:\n        dataset (:obj:`Dataset`): The dataset to be repeated.\n        times (int): Repeat times.\n    \"\"\"\n\n    def __init__(self, dataset, times):\n        self.dataset = dataset\n        self.times = times\n        self.CLASSES = dataset.CLASSES\n        self.PALETTE = dataset.PALETTE\n        self._ori_len = len(self.dataset)\n\n    def __getitem__(self, idx):\n        \"\"\"Get item from original dataset.\"\"\"\n        return self.dataset[idx % self._ori_len]\n\n    def __len__(self):\n        \"\"\"The length is multiplied by ``times``\"\"\"\n        return self.times * self._ori_len\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/drive.py",
    "content": "import os.path as osp\n\nfrom .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass DRIVEDataset(CustomDataset):\n    \"\"\"DRIVE dataset.\n\n    In segmentation map annotation for DRIVE, 0 stands for background, which is\n    included in 2 categories. ``reduce_zero_label`` is fixed to False. The\n    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to\n    '_manual1.png'.\n    \"\"\"\n\n    CLASSES = ('background', 'vessel')\n\n    PALETTE = [[120, 120, 120], [6, 230, 230]]\n\n    def __init__(self, **kwargs):\n        super(DRIVEDataset, self).__init__(\n            img_suffix='.png',\n            seg_map_suffix='_manual1.png',\n            reduce_zero_label=False,\n            **kwargs)\n        assert osp.exists(self.img_dir)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/hrf.py",
    "content": "import os.path as osp\n\nfrom .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass HRFDataset(CustomDataset):\n    \"\"\"HRF dataset.\n\n    In segmentation map annotation for HRF, 0 stands for background, which is\n    included in 2 categories. ``reduce_zero_label`` is fixed to False. The\n    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to\n    '.png'.\n    \"\"\"\n\n    CLASSES = ('background', 'vessel')\n\n    PALETTE = [[120, 120, 120], [6, 230, 230]]\n\n    def __init__(self, **kwargs):\n        super(HRFDataset, self).__init__(\n            img_suffix='.png',\n            seg_map_suffix='.png',\n            reduce_zero_label=False,\n            **kwargs)\n        assert osp.exists(self.img_dir)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/pascal_context.py",
    "content": "import os.path as osp\n\nfrom .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass PascalContextDataset(CustomDataset):\n    \"\"\"PascalContext dataset.\n\n    In segmentation map annotation for PascalContext, 0 stands for background,\n    which is included in 60 categories. ``reduce_zero_label`` is fixed to\n    False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is\n    fixed to '.png'.\n\n    Args:\n        split (str): Split txt file for PascalContext.\n    \"\"\"\n\n    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',\n               'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse',\n               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',\n               'tvmonitor', 'bag', 'bed', 'bench', 'book', 'building',\n               'cabinet', 'ceiling', 'cloth', 'computer', 'cup', 'door',\n               'fence', 'floor', 'flower', 'food', 'grass', 'ground',\n               'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform',\n               'sign', 'plate', 'road', 'rock', 'shelves', 'sidewalk', 'sky',\n               'snow', 'bedclothes', 'track', 'tree', 'truck', 'wall', 'water',\n               'window', 'wood')\n\n    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],\n               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],\n               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],\n               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],\n               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],\n               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],\n               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],\n               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],\n               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],\n               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],\n               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],\n               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],\n               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],\n               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],\n               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]\n\n    def __init__(self, split, **kwargs):\n        super(PascalContextDataset, self).__init__(\n            img_suffix='.jpg',\n            seg_map_suffix='.png',\n            split=split,\n            reduce_zero_label=False,\n            **kwargs)\n        assert osp.exists(self.img_dir) and self.split is not None\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/__init__.py",
    "content": "from .compose import Compose\nfrom .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,\n                        Transpose, to_tensor)\nfrom .loading import LoadAnnotations, LoadImageFromFile\nfrom .test_time_aug import MultiScaleFlipAug\nfrom .transforms import (CLAHE, AdjustGamma, Normalize, Pad,\n                         PhotoMetricDistortion, RandomCrop, RandomFlip,\n                         RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)\n\n__all__ = [\n    'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',\n    'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',\n    'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',\n    'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',\n    'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'\n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/compose.py",
    "content": "import collections\n\nfrom mmcv.utils import build_from_cfg\n\nfrom ..builder import PIPELINES\n\n\n@PIPELINES.register_module()\nclass Compose(object):\n    \"\"\"Compose multiple transforms sequentially.\n\n    Args:\n        transforms (Sequence[dict | callable]): Sequence of transform object or\n            config dict to be composed.\n    \"\"\"\n\n    def __init__(self, transforms):\n        assert isinstance(transforms, collections.abc.Sequence)\n        self.transforms = []\n        for transform in transforms:\n            if isinstance(transform, dict):\n                transform = build_from_cfg(transform, PIPELINES)\n                self.transforms.append(transform)\n            elif callable(transform):\n                self.transforms.append(transform)\n            else:\n                raise TypeError('transform must be callable or a dict')\n\n    def __call__(self, data):\n        \"\"\"Call function to apply transforms sequentially.\n\n        Args:\n            data (dict): A result dict contains the data to transform.\n\n        Returns:\n           dict: Transformed data.\n        \"\"\"\n\n        for t in self.transforms:\n            data = t(data)\n            if data is None:\n                return None\n        return data\n\n    def __repr__(self):\n        format_string = self.__class__.__name__ + '('\n        for t in self.transforms:\n            format_string += '\\n'\n            format_string += f'    {t}'\n        format_string += '\\n)'\n        return format_string\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/formating.py",
    "content": "from collections.abc import Sequence\n\nimport mmcv\nimport numpy as np\nimport torch\nfrom mmcv.parallel import DataContainer as DC\n\nfrom ..builder import PIPELINES\n\n\ndef to_tensor(data):\n    \"\"\"Convert objects of various python types to :obj:`torch.Tensor`.\n\n    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,\n    :class:`Sequence`, :class:`int` and :class:`float`.\n\n    Args:\n        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to\n            be converted.\n    \"\"\"\n\n    if isinstance(data, torch.Tensor):\n        return data\n    elif isinstance(data, np.ndarray):\n        return torch.from_numpy(data)\n    elif isinstance(data, Sequence) and not mmcv.is_str(data):\n        return torch.tensor(data)\n    elif isinstance(data, int):\n        return torch.LongTensor([data])\n    elif isinstance(data, float):\n        return torch.FloatTensor([data])\n    else:\n        raise TypeError(f'type {type(data)} cannot be converted to tensor.')\n\n\n@PIPELINES.register_module()\nclass ToTensor(object):\n    \"\"\"Convert some results to :obj:`torch.Tensor` by given keys.\n\n    Args:\n        keys (Sequence[str]): Keys that need to be converted to Tensor.\n    \"\"\"\n\n    def __init__(self, keys):\n        self.keys = keys\n\n    def __call__(self, results):\n        \"\"\"Call function to convert data in results to :obj:`torch.Tensor`.\n\n        Args:\n            results (dict): Result dict contains the data to convert.\n\n        Returns:\n            dict: The result dict contains the data converted\n                to :obj:`torch.Tensor`.\n        \"\"\"\n\n        for key in self.keys:\n            results[key] = to_tensor(results[key])\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + f'(keys={self.keys})'\n\n\n@PIPELINES.register_module()\nclass ImageToTensor(object):\n    \"\"\"Convert image to :obj:`torch.Tensor` by given keys.\n\n    The dimension order of input image is (H, W, C). The pipeline will convert\n    it to (C, H, W). 
If only 2 dimension (H, W) is given, the output would be\n    (1, H, W).\n\n    Args:\n        keys (Sequence[str]): Key of images to be converted to Tensor.\n    \"\"\"\n\n    def __init__(self, keys):\n        self.keys = keys\n\n    def __call__(self, results):\n        \"\"\"Call function to convert image in results to :obj:`torch.Tensor` and\n        transpose the channel order.\n\n        Args:\n            results (dict): Result dict contains the image data to convert.\n\n        Returns:\n            dict: The result dict contains the image converted\n                to :obj:`torch.Tensor` and transposed to (C, H, W) order.\n        \"\"\"\n\n        for key in self.keys:\n            img = results[key]\n            if len(img.shape) < 3:\n                img = np.expand_dims(img, -1)\n            results[key] = to_tensor(img.transpose(2, 0, 1))\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + f'(keys={self.keys})'\n\n\n@PIPELINES.register_module()\nclass Transpose(object):\n    \"\"\"Transpose some results by given keys.\n\n    Args:\n        keys (Sequence[str]): Keys of results to be transposed.\n        order (Sequence[int]): Order of transpose.\n    \"\"\"\n\n    def __init__(self, keys, order):\n        self.keys = keys\n        self.order = order\n\n    def __call__(self, results):\n        \"\"\"Call function to convert image in results to :obj:`torch.Tensor` and\n        transpose the channel order.\n\n        Args:\n            results (dict): Result dict contains the image data to convert.\n\n        Returns:\n            dict: The result dict contains the image converted\n                to :obj:`torch.Tensor` and transposed to (C, H, W) order.\n        \"\"\"\n\n        for key in self.keys:\n            results[key] = results[key].transpose(self.order)\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + \\\n               f'(keys={self.keys}, order={self.order})'\n\n\n@PIPELINES.register_module()\nclass ToDataContainer(object):\n    \"\"\"Convert results to :obj:`mmcv.DataContainer` by given fields.\n\n    Args:\n        fields (Sequence[dict]): Each field is a dict like\n            ``dict(key='xxx', **kwargs)``. The ``key`` in result will\n            be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.\n            Default: ``(dict(key='img', stack=True),\n            dict(key='gt_semantic_seg'))``.\n    \"\"\"\n\n    def __init__(self,\n                 fields=(dict(key='img',\n                              stack=True), dict(key='gt_semantic_seg'))):\n        self.fields = fields\n\n    def __call__(self, results):\n        \"\"\"Call function to convert data in results to\n        :obj:`mmcv.DataContainer`.\n\n        Args:\n            results (dict): Result dict contains the data to convert.\n\n        Returns:\n            dict: The result dict contains the data converted to\n                :obj:`mmcv.DataContainer`.\n        \"\"\"\n\n        for field in self.fields:\n            field = field.copy()\n            key = field.pop('key')\n            results[key] = DC(results[key], **field)\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + f'(fields={self.fields})'\n\n\n@PIPELINES.register_module()\nclass DefaultFormatBundle(object):\n    \"\"\"Default formatting bundle.\n\n    It simplifies the pipeline of formatting common fields, including \"img\"\n    and \"gt_semantic_seg\". 
These fields are formatted as follows.\n\n    - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)\n    - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,\n                       (3)to DataContainer (stack=True)\n    \"\"\"\n\n    def __call__(self, results):\n        \"\"\"Call function to transform and format common fields in results.\n\n        Args:\n            results (dict): Result dict contains the data to convert.\n\n        Returns:\n            dict: The result dict contains the data that is formatted with\n                default bundle.\n        \"\"\"\n\n        if 'img' in results:\n            img = results['img']\n            if len(img.shape) < 3:\n                img = np.expand_dims(img, -1)\n            img = np.ascontiguousarray(img.transpose(2, 0, 1))\n            results['img'] = DC(to_tensor(img), stack=True)\n        if 'gt_semantic_seg' in results:\n            # convert to long\n            results['gt_semantic_seg'] = DC(\n                to_tensor(results['gt_semantic_seg'][None,\n                                                     ...].astype(np.int64)),\n                stack=True)\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\n@PIPELINES.register_module()\nclass Collect(object):\n    \"\"\"Collect data from the loader relevant to the specific task.\n\n    This is usually the last stage of the data loader pipeline. Typically keys\n    is set to some subset of \"img\", \"gt_semantic_seg\".\n\n    The \"img_meta\" item is always populated.  The contents of the \"img_meta\"\n    dictionary depends on \"meta_keys\". By default this includes:\n\n        - \"img_shape\": shape of the image input to the network as a tuple\n            (h, w, c).  Note that images may be zero padded on the bottom/right\n            if the batch tensor is larger than this shape.\n\n        - \"scale_factor\": a float indicating the preprocessing scale\n\n        - \"flip\": a boolean indicating if image flip transform was used\n\n        - \"filename\": path to the image file\n\n        - \"ori_shape\": original shape of the image as a tuple (h, w, c)\n\n        - \"pad_shape\": image shape after padding\n\n        - \"img_norm_cfg\": a dict of normalization information:\n            - mean - per channel mean subtraction\n            - std - per channel std divisor\n            - to_rgb - bool indicating if bgr was converted to rgb\n\n    Args:\n        keys (Sequence[str]): Keys of results to be collected in ``data``.\n        meta_keys (Sequence[str], optional): Meta keys to be converted to\n            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.\n            Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',\n            'pad_shape', 'scale_factor', 'flip', 'flip_direction',\n            'img_norm_cfg')``\n    \"\"\"\n\n    def __init__(self,\n                 keys,\n                 meta_keys=('filename', 'ori_filename', 'ori_shape',\n                            'img_shape', 'pad_shape', 'scale_factor', 'flip',\n                            'flip_direction', 'img_norm_cfg')):\n        self.keys = keys\n        self.meta_keys = meta_keys\n\n    def __call__(self, results):\n        \"\"\"Call function to collect keys in results. 
The keys in ``meta_keys``\n        will be converted to :obj:`mmcv.DataContainer`.\n\n        Args:\n            results (dict): Result dict contains the data to collect.\n\n        Returns:\n            dict: The result dict contains the following keys\n                - keys in ``self.keys``\n                - ``img_metas``\n        \"\"\"\n\n        data = {}\n        img_meta = {}\n        for key in self.meta_keys:\n            img_meta[key] = results[key]\n        data['img_metas'] = DC(img_meta, cpu_only=True)\n        for key in self.keys:\n            data[key] = results[key]\n        return data\n\n    def __repr__(self):\n        return self.__class__.__name__ + \\\n               f'(keys={self.keys}, meta_keys={self.meta_keys})'\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/loading.py",
    "content": "import os.path as osp\n\nimport mmcv\nimport numpy as np\n\nfrom ..builder import PIPELINES\n\n\n@PIPELINES.register_module()\nclass LoadImageFromFile(object):\n    \"\"\"Load an image from file.\n\n    Required keys are \"img_prefix\" and \"img_info\" (a dict that must contain the\n    key \"filename\"). Added or updated keys are \"filename\", \"img\", \"img_shape\",\n    \"ori_shape\" (same as `img_shape`), \"pad_shape\" (same as `img_shape`),\n    \"scale_factor\" (1.0) and \"img_norm_cfg\" (means=0 and stds=1).\n\n    Args:\n        to_float32 (bool): Whether to convert the loaded image to a float32\n            numpy array. If set to False, the loaded image is an uint8 array.\n            Defaults to False.\n        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.\n            Defaults to 'color'.\n        file_client_args (dict): Arguments to instantiate a FileClient.\n            See :class:`mmcv.fileio.FileClient` for details.\n            Defaults to ``dict(backend='disk')``.\n        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:\n            'cv2'\n    \"\"\"\n\n    def __init__(self,\n                 to_float32=False,\n                 color_type='color',\n                 file_client_args=dict(backend='disk'),\n                 imdecode_backend='cv2'):\n        self.to_float32 = to_float32\n        self.color_type = color_type\n        self.file_client_args = file_client_args.copy()\n        self.file_client = None\n        self.imdecode_backend = imdecode_backend\n\n    def __call__(self, results):\n        \"\"\"Call functions to load image and get image meta information.\n\n        Args:\n            results (dict): Result dict from :obj:`mmseg.CustomDataset`.\n\n        Returns:\n            dict: The dict contains loaded image and meta information.\n        \"\"\"\n\n        if self.file_client is None:\n            self.file_client = mmcv.FileClient(**self.file_client_args)\n\n        if results.get('img_prefix') is not None:\n            filename = osp.join(results['img_prefix'],\n                                results['img_info']['filename'])\n        else:\n            filename = results['img_info']['filename']\n        img_bytes = self.file_client.get(filename)\n        img = mmcv.imfrombytes(\n            img_bytes, flag=self.color_type, backend=self.imdecode_backend)\n        if self.to_float32:\n            img = img.astype(np.float32)\n\n        results['filename'] = filename\n        results['ori_filename'] = results['img_info']['filename']\n        results['img'] = img\n        results['img_shape'] = img.shape\n        results['ori_shape'] = img.shape\n        # Set initial values for default meta_keys\n        results['pad_shape'] = img.shape\n        results['scale_factor'] = 1.0\n        num_channels = 1 if len(img.shape) < 3 else img.shape[2]\n        results['img_norm_cfg'] = dict(\n            mean=np.zeros(num_channels, dtype=np.float32),\n            std=np.ones(num_channels, dtype=np.float32),\n            to_rgb=False)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(to_float32={self.to_float32},'\n        repr_str += f\"color_type='{self.color_type}',\"\n        repr_str += f\"imdecode_backend='{self.imdecode_backend}')\"\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass LoadAnnotations(object):\n    \"\"\"Load annotations for semantic segmentation.\n\n    Args:\n        reduce_zero_label (bool): Whether reduce 
all label value by 1.\n            Usually used for datasets where 0 is background label.\n            Default: False.\n        file_client_args (dict): Arguments to instantiate a FileClient.\n            See :class:`mmcv.fileio.FileClient` for details.\n            Defaults to ``dict(backend='disk')``.\n        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:\n            'pillow'\n    \"\"\"\n\n    def __init__(self,\n                 reduce_zero_label=False,\n                 file_client_args=dict(backend='disk'),\n                 imdecode_backend='pillow'):\n        self.reduce_zero_label = reduce_zero_label\n        self.file_client_args = file_client_args.copy()\n        self.file_client = None\n        self.imdecode_backend = imdecode_backend\n\n    def __call__(self, results):\n        \"\"\"Call function to load multiple types annotations.\n\n        Args:\n            results (dict): Result dict from :obj:`mmseg.CustomDataset`.\n\n        Returns:\n            dict: The dict contains loaded semantic segmentation annotations.\n        \"\"\"\n\n        if self.file_client is None:\n            self.file_client = mmcv.FileClient(**self.file_client_args)\n\n        if results.get('seg_prefix', None) is not None:\n            filename = osp.join(results['seg_prefix'],\n                                results['ann_info']['seg_map'])\n        else:\n            filename = results['ann_info']['seg_map']\n        img_bytes = self.file_client.get(filename)\n        gt_semantic_seg = mmcv.imfrombytes(\n            img_bytes, flag='unchanged',\n            backend=self.imdecode_backend).squeeze().astype(np.uint8)\n        # modify if custom classes\n        if results.get('label_map', None) is not None:\n            for old_id, new_id in results['label_map'].items():\n                gt_semantic_seg[gt_semantic_seg == old_id] = new_id\n        # reduce zero_label\n        if self.reduce_zero_label:\n            # avoid using underflow conversion\n            gt_semantic_seg[gt_semantic_seg == 0] = 255\n            gt_semantic_seg = gt_semantic_seg - 1\n            gt_semantic_seg[gt_semantic_seg == 254] = 255\n        results['gt_semantic_seg'] = gt_semantic_seg\n        results['seg_fields'].append('gt_semantic_seg')\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(reduce_zero_label={self.reduce_zero_label},'\n        repr_str += f\"imdecode_backend='{self.imdecode_backend}')\"\n        return repr_str\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/test_time_aug.py",
    "content": "import warnings\n\nimport mmcv\n\nfrom ..builder import PIPELINES\nfrom .compose import Compose\n\n\n@PIPELINES.register_module()\nclass MultiScaleFlipAug(object):\n    \"\"\"Test-time augmentation with multiple scales and flipping.\n\n    An example configuration is as followed:\n\n    .. code-block::\n\n        img_scale=(2048, 1024),\n        img_ratios=[0.5, 1.0],\n        flip=True,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ]\n\n    After MultiScaleFLipAug with above configuration, the results are wrapped\n    into lists of the same length as followed:\n\n    .. code-block::\n\n        dict(\n            img=[...],\n            img_shape=[...],\n            scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]\n            flip=[False, True, False, True]\n            ...\n        )\n\n    Args:\n        transforms (list[dict]): Transforms to apply in each augmentation.\n        img_scale (None | tuple | list[tuple]): Images scales for resizing.\n        img_ratios (float | list[float]): Image ratios for resizing\n        flip (bool): Whether apply flip augmentation. Default: False.\n        flip_direction (str | list[str]): Flip augmentation directions,\n            options are \"horizontal\" and \"vertical\". If flip_direction is list,\n            multiple flip augmentations will be applied.\n            It has no effect when flip == False. Default: \"horizontal\".\n    \"\"\"\n\n    def __init__(self,\n                 transforms,\n                 img_scale,\n                 img_ratios=None,\n                 flip=False,\n                 flip_direction='horizontal'):\n        self.transforms = Compose(transforms)\n        if img_ratios is not None:\n            img_ratios = img_ratios if isinstance(img_ratios,\n                                                  list) else [img_ratios]\n            assert mmcv.is_list_of(img_ratios, float)\n        if img_scale is None:\n            # mode 1: given img_scale=None and a range of image ratio\n            self.img_scale = None\n            assert mmcv.is_list_of(img_ratios, float)\n        elif isinstance(img_scale, tuple) and mmcv.is_list_of(\n                img_ratios, float):\n            assert len(img_scale) == 2\n            # mode 2: given a scale and a range of image ratio\n            self.img_scale = [(int(img_scale[0] * ratio),\n                               int(img_scale[1] * ratio))\n                              for ratio in img_ratios]\n        else:\n            # mode 3: given multiple scales\n            self.img_scale = img_scale if isinstance(img_scale,\n                                                     list) else [img_scale]\n        assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None\n        self.flip = flip\n        self.img_ratios = img_ratios\n        self.flip_direction = flip_direction if isinstance(\n            flip_direction, list) else [flip_direction]\n        assert mmcv.is_list_of(self.flip_direction, str)\n        if not self.flip and self.flip_direction != ['horizontal']:\n            warnings.warn(\n                'flip_direction has no effect when flip is set to False')\n        if (self.flip\n                and not any([t['type'] == 'RandomFlip' for t in 
transforms])):\n            warnings.warn(\n                'flip has no effect when RandomFlip is not in transforms')\n\n    def __call__(self, results):\n        \"\"\"Call function to apply test time augment transforms on results.\n\n        Args:\n            results (dict): Result dict contains the data to transform.\n\n        Returns:\n           dict[str: list]: The augmented data, where each value is wrapped\n               into a list.\n        \"\"\"\n\n        aug_data = []\n        if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):\n            h, w = results['img'].shape[:2]\n            img_scale = [(int(w * ratio), int(h * ratio))\n                         for ratio in self.img_ratios]\n        else:\n            img_scale = self.img_scale\n        flip_aug = [False, True] if self.flip else [False]\n        for scale in img_scale:\n            for flip in flip_aug:\n                for direction in self.flip_direction:\n                    _results = results.copy()\n                    _results['scale'] = scale\n                    _results['flip'] = flip\n                    _results['flip_direction'] = direction\n                    data = self.transforms(_results)\n                    aug_data.append(data)\n        # list of dict to dict of list\n        aug_data_dict = {key: [] for key in aug_data[0]}\n        for data in aug_data:\n            for key, val in data.items():\n                aug_data_dict[key].append(val)\n        return aug_data_dict\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(transforms={self.transforms}, '\n        repr_str += f'img_scale={self.img_scale}, flip={self.flip})'\n        repr_str += f'flip_direction={self.flip_direction}'\n        return repr_str\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/transforms.py",
    "content": "import mmcv\nimport numpy as np\nfrom mmcv.utils import deprecated_api_warning, is_tuple_of\nfrom numpy import random\n\nfrom ..builder import PIPELINES\n\n\n@PIPELINES.register_module()\nclass Resize(object):\n    \"\"\"Resize images & seg.\n\n    This transform resizes the input image to some scale. If the input dict\n    contains the key \"scale\", then the scale in the input dict is used,\n    otherwise the specified scale in the init method is used.\n\n    ``img_scale`` can be Nong, a tuple (single-scale) or a list of tuple\n    (multi-scale). There are 4 multiscale modes:\n\n    - ``ratio_range is not None``:\n    1. When img_scale is None, img_scale is the shape of image in results\n        (img_scale = results['img'].shape[:2]) and the image is resized based\n        on the original size. (mode 1)\n    2. When img_scale is a tuple (single-scale), randomly sample a ratio from\n        the ratio range and multiply it with the image scale. (mode 2)\n\n    - ``ratio_range is None and multiscale_mode == \"range\"``: randomly sample a\n    scale from the a range. (mode 3)\n\n    - ``ratio_range is None and multiscale_mode == \"value\"``: randomly sample a\n    scale from multiple scales. (mode 4)\n\n    Args:\n        img_scale (tuple or list[tuple]): Images scales for resizing.\n        multiscale_mode (str): Either \"range\" or \"value\".\n        ratio_range (tuple[float]): (min_ratio, max_ratio)\n        keep_ratio (bool): Whether to keep the aspect ratio when resizing the\n            image.\n    \"\"\"\n\n    def __init__(self,\n                 img_scale=None,\n                 multiscale_mode='range',\n                 ratio_range=None,\n                 keep_ratio=True):\n        if img_scale is None:\n            self.img_scale = None\n        else:\n            if isinstance(img_scale, list):\n                self.img_scale = img_scale\n            else:\n                self.img_scale = [img_scale]\n            assert mmcv.is_list_of(self.img_scale, tuple)\n\n        if ratio_range is not None:\n            # mode 1: given img_scale=None and a range of image ratio\n            # mode 2: given a scale and a range of image ratio\n            assert self.img_scale is None or len(self.img_scale) == 1\n        else:\n            # mode 3 and 4: given multiple scales or a range of scales\n            assert multiscale_mode in ['value', 'range']\n\n        self.multiscale_mode = multiscale_mode\n        self.ratio_range = ratio_range\n        self.keep_ratio = keep_ratio\n\n    @staticmethod\n    def random_select(img_scales):\n        \"\"\"Randomly select an img_scale from given candidates.\n\n        Args:\n            img_scales (list[tuple]): Images scales for selection.\n\n        Returns:\n            (tuple, int): Returns a tuple ``(img_scale, scale_dix)``,\n                where ``img_scale`` is the selected image scale and\n                ``scale_idx`` is the selected index in the given candidates.\n        \"\"\"\n\n        assert mmcv.is_list_of(img_scales, tuple)\n        scale_idx = np.random.randint(len(img_scales))\n        img_scale = img_scales[scale_idx]\n        return img_scale, scale_idx\n\n    @staticmethod\n    def random_sample(img_scales):\n        \"\"\"Randomly sample an img_scale when ``multiscale_mode=='range'``.\n\n        Args:\n            img_scales (list[tuple]): Images scale range for sampling.\n                There must be two tuples in img_scales, which specify the lower\n                and uper bound of image scales.\n\n      
  Returns:\n            (tuple, None): Returns a tuple ``(img_scale, None)``, where\n                ``img_scale`` is sampled scale and None is just a placeholder\n                to be consistent with :func:`random_select`.\n        \"\"\"\n\n        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2\n        img_scale_long = [max(s) for s in img_scales]\n        img_scale_short = [min(s) for s in img_scales]\n        long_edge = np.random.randint(\n            min(img_scale_long),\n            max(img_scale_long) + 1)\n        short_edge = np.random.randint(\n            min(img_scale_short),\n            max(img_scale_short) + 1)\n        img_scale = (long_edge, short_edge)\n        return img_scale, None\n\n    @staticmethod\n    def random_sample_ratio(img_scale, ratio_range):\n        \"\"\"Randomly sample an img_scale when ``ratio_range`` is specified.\n\n        A ratio will be randomly sampled from the range specified by\n        ``ratio_range``. Then it would be multiplied with ``img_scale`` to\n        generate sampled scale.\n\n        Args:\n            img_scale (tuple): Images scale base to multiply with ratio.\n            ratio_range (tuple[float]): The minimum and maximum ratio to scale\n                the ``img_scale``.\n\n        Returns:\n            (tuple, None): Returns a tuple ``(scale, None)``, where\n                ``scale`` is sampled ratio multiplied with ``img_scale`` and\n                None is just a placeholder to be consistent with\n                :func:`random_select`.\n        \"\"\"\n\n        assert isinstance(img_scale, tuple) and len(img_scale) == 2\n        min_ratio, max_ratio = ratio_range\n        assert min_ratio <= max_ratio\n        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio\n        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)\n        return scale, None\n\n    def _random_scale(self, results):\n        \"\"\"Randomly sample an img_scale according to ``ratio_range`` and\n        ``multiscale_mode``.\n\n        If ``ratio_range`` is specified, a ratio will be sampled and be\n        multiplied with ``img_scale``.\n        If multiple scales are specified by ``img_scale``, a scale will be\n        sampled according to ``multiscale_mode``.\n        Otherwise, single scale will be used.\n\n        Args:\n            results (dict): Result dict from :obj:`dataset`.\n\n        Returns:\n            dict: Two new keys 'scale` and 'scale_idx` are added into\n                ``results``, which would be used by subsequent pipelines.\n        \"\"\"\n\n        if self.ratio_range is not None:\n            if self.img_scale is None:\n                h, w = results['img'].shape[:2]\n                scale, scale_idx = self.random_sample_ratio((w, h),\n                                                            self.ratio_range)\n            else:\n                scale, scale_idx = self.random_sample_ratio(\n                    self.img_scale[0], self.ratio_range)\n        elif len(self.img_scale) == 1:\n            scale, scale_idx = self.img_scale[0], 0\n        elif self.multiscale_mode == 'range':\n            scale, scale_idx = self.random_sample(self.img_scale)\n        elif self.multiscale_mode == 'value':\n            scale, scale_idx = self.random_select(self.img_scale)\n        else:\n            raise NotImplementedError\n\n        results['scale'] = scale\n        results['scale_idx'] = scale_idx\n\n    def _resize_img(self, results):\n        \"\"\"Resize images with 
``results['scale']``.\"\"\"\n        if self.keep_ratio:\n            img, scale_factor = mmcv.imrescale(\n                results['img'], results['scale'], return_scale=True)\n            # the w_scale and h_scale has minor difference\n            # a real fix should be done in the mmcv.imrescale in the future\n            new_h, new_w = img.shape[:2]\n            h, w = results['img'].shape[:2]\n            w_scale = new_w / w\n            h_scale = new_h / h\n        else:\n            img, w_scale, h_scale = mmcv.imresize(\n                results['img'], results['scale'], return_scale=True)\n        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],\n                                dtype=np.float32)\n        results['img'] = img\n        results['img_shape'] = img.shape\n        results['pad_shape'] = img.shape  # in case that there is no padding\n        results['scale_factor'] = scale_factor\n        results['keep_ratio'] = self.keep_ratio\n\n    def _resize_seg(self, results):\n        \"\"\"Resize semantic segmentation map with ``results['scale']``.\"\"\"\n        for key in results.get('seg_fields', []):\n            if self.keep_ratio:\n                gt_seg = mmcv.imrescale(\n                    results[key], results['scale'], interpolation='nearest')\n            else:\n                gt_seg = mmcv.imresize(\n                    results[key], results['scale'], interpolation='nearest')\n            results[key] = gt_seg\n\n    def __call__(self, results):\n        \"\"\"Call function to resize images, bounding boxes, masks, semantic\n        segmentation map.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',\n                'keep_ratio' keys are added into result dict.\n        \"\"\"\n\n        if 'scale' not in results:\n            self._random_scale(results)\n        self._resize_img(results)\n        self._resize_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += (f'(img_scale={self.img_scale}, '\n                     f'multiscale_mode={self.multiscale_mode}, '\n                     f'ratio_range={self.ratio_range}, '\n                     f'keep_ratio={self.keep_ratio})')\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass RandomFlip(object):\n    \"\"\"Flip the image & seg.\n\n    If the input dict contains the key \"flip\", then the flag will be used,\n    otherwise it will be randomly decided by a ratio specified in the init\n    method.\n\n    Args:\n        prob (float, optional): The flipping probability. Default: None.\n        direction(str, optional): The flipping direction. Options are\n            'horizontal' and 'vertical'. 
Default: 'horizontal'.\n    \"\"\"\n\n    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')\n    def __init__(self, prob=None, direction='horizontal'):\n        self.prob = prob\n        self.direction = direction\n        if prob is not None:\n            assert prob >= 0 and prob <= 1\n        assert direction in ['horizontal', 'vertical']\n\n    def __call__(self, results):\n        \"\"\"Call function to flip bounding boxes, masks, semantic segmentation\n        maps.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Flipped results, 'flip', 'flip_direction' keys are added into\n                result dict.\n        \"\"\"\n\n        if 'flip' not in results:\n            flip = True if np.random.rand() < self.prob else False\n            results['flip'] = flip\n        if 'flip_direction' not in results:\n            results['flip_direction'] = self.direction\n        if results['flip']:\n            # flip image\n            results['img'] = mmcv.imflip(\n                results['img'], direction=results['flip_direction'])\n\n            # flip segs\n            for key in results.get('seg_fields', []):\n                # use copy() to make numpy stride positive\n                results[key] = mmcv.imflip(\n                    results[key], direction=results['flip_direction']).copy()\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + f'(prob={self.prob})'\n\n\n@PIPELINES.register_module()\nclass Pad(object):\n    \"\"\"Pad the image & mask.\n\n    There are two padding modes: (1) pad to a fixed size and (2) pad to the\n    minimum size that is divisible by some number.\n    Added keys are \"pad_shape\", \"pad_fixed_size\", \"pad_size_divisor\",\n\n    Args:\n        size (tuple, optional): Fixed padding size.\n        size_divisor (int, optional): The divisor of padded size.\n        pad_val (float, optional): Padding value. 
Default: 0.\n        seg_pad_val (float, optional): Padding value of segmentation map.\n            Default: 255.\n    \"\"\"\n\n    def __init__(self,\n                 size=None,\n                 size_divisor=None,\n                 pad_val=0,\n                 seg_pad_val=255):\n        self.size = size\n        self.size_divisor = size_divisor\n        self.pad_val = pad_val\n        self.seg_pad_val = seg_pad_val\n        # only one of size and size_divisor should be valid\n        assert size is not None or size_divisor is not None\n        assert size is None or size_divisor is None\n\n    def _pad_img(self, results):\n        \"\"\"Pad images according to ``self.size``.\"\"\"\n        if self.size is not None:\n            padded_img = mmcv.impad(\n                results['img'], shape=self.size, pad_val=self.pad_val)\n        elif self.size_divisor is not None:\n            padded_img = mmcv.impad_to_multiple(\n                results['img'], self.size_divisor, pad_val=self.pad_val)\n        results['img'] = padded_img\n        results['pad_shape'] = padded_img.shape\n        results['pad_fixed_size'] = self.size\n        results['pad_size_divisor'] = self.size_divisor\n\n    def _pad_seg(self, results):\n        \"\"\"Pad masks according to ``results['pad_shape']``.\"\"\"\n        for key in results.get('seg_fields', []):\n            results[key] = mmcv.impad(\n                results[key],\n                shape=results['pad_shape'][:2],\n                pad_val=self.seg_pad_val)\n\n    def __call__(self, results):\n        \"\"\"Call function to pad images, masks, semantic segmentation maps.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n\n        self._pad_img(results)\n        self._pad_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \\\n                    f'pad_val={self.pad_val})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass Normalize(object):\n    \"\"\"Normalize the image.\n\n    Added key is \"img_norm_cfg\".\n\n    Args:\n        mean (sequence): Mean values of 3 channels.\n        std (sequence): Std values of 3 channels.\n        to_rgb (bool): Whether to convert the image from BGR to RGB,\n            default is true.\n    \"\"\"\n\n    def __init__(self, mean, std, to_rgb=True):\n        self.mean = np.array(mean, dtype=np.float32)\n        self.std = np.array(std, dtype=np.float32)\n        self.to_rgb = to_rgb\n\n    def __call__(self, results):\n        \"\"\"Call function to normalize images.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Normalized results, 'img_norm_cfg' key is added into\n                result dict.\n        \"\"\"\n\n        results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,\n                                          self.to_rgb)\n        results['img_norm_cfg'] = dict(\n            mean=self.mean, std=self.std, to_rgb=self.to_rgb)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \\\n                    f'{self.to_rgb})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass Rerange(object):\n    \"\"\"Rerange the image pixel value.\n\n    Args:\n        
min_value (float or int): Minimum value of the reranged image.\n            Default: 0.\n        max_value (float or int): Maximum value of the reranged image.\n            Default: 255.\n    \"\"\"\n\n    def __init__(self, min_value=0, max_value=255):\n        assert isinstance(min_value, float) or isinstance(min_value, int)\n        assert isinstance(max_value, float) or isinstance(max_value, int)\n        assert min_value < max_value\n        self.min_value = min_value\n        self.max_value = max_value\n\n    def __call__(self, results):\n        \"\"\"Call function to rerange images.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Reranged results.\n        \"\"\"\n\n        img = results['img']\n        img_min_value = np.min(img)\n        img_max_value = np.max(img)\n\n        assert img_min_value < img_max_value\n        # rerange to [0, 1]\n        img = (img - img_min_value) / (img_max_value - img_min_value)\n        # rerange to [min_value, max_value]\n        img = img * (self.max_value - self.min_value) + self.min_value\n        results['img'] = img\n\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass CLAHE(object):\n    \"\"\"Use CLAHE method to process the image.\n\n    See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].\n    Graphics Gems, 1994:474-485.` for more information.\n\n    Args:\n        clip_limit (float): Threshold for contrast limiting. Default: 40.0.\n        tile_grid_size (tuple[int]): Size of grid for histogram equalization.\n            Input image will be divided into equally sized rectangular tiles.\n            It defines the number of tiles in row and column. 
Default: (8, 8).\n    \"\"\"\n\n    def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):\n        assert isinstance(clip_limit, (float, int))\n        self.clip_limit = clip_limit\n        assert is_tuple_of(tile_grid_size, int)\n        assert len(tile_grid_size) == 2\n        self.tile_grid_size = tile_grid_size\n\n    def __call__(self, results):\n        \"\"\"Call function to Use CLAHE method process images.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Processed results.\n        \"\"\"\n\n        for i in range(results['img'].shape[2]):\n            results['img'][:, :, i] = mmcv.clahe(\n                np.array(results['img'][:, :, i], dtype=np.uint8),\n                self.clip_limit, self.tile_grid_size)\n\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(clip_limit={self.clip_limit}, '\\\n                    f'tile_grid_size={self.tile_grid_size})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass RandomCrop(object):\n    \"\"\"Random crop the image & seg.\n\n    Args:\n        crop_size (tuple): Expected size after cropping, (h, w).\n        cat_max_ratio (float): The maximum ratio that single category could\n            occupy.\n    \"\"\"\n\n    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):\n        assert crop_size[0] > 0 and crop_size[1] > 0\n        self.crop_size = crop_size\n        self.cat_max_ratio = cat_max_ratio\n        self.ignore_index = ignore_index\n\n    def get_crop_bbox(self, img):\n        \"\"\"Randomly get a crop bounding box.\"\"\"\n        margin_h = max(img.shape[0] - self.crop_size[0], 0)\n        margin_w = max(img.shape[1] - self.crop_size[1], 0)\n        offset_h = np.random.randint(0, margin_h + 1)\n        offset_w = np.random.randint(0, margin_w + 1)\n        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]\n        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]\n\n        return crop_y1, crop_y2, crop_x1, crop_x2\n\n    def crop(self, img, crop_bbox):\n        \"\"\"Crop from ``img``\"\"\"\n        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox\n        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]\n        return img\n\n    def __call__(self, results):\n        \"\"\"Call function to randomly crop images, semantic segmentation maps.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Randomly cropped results, 'img_shape' key in result dict is\n                updated according to crop size.\n        \"\"\"\n\n        img = results['img']\n        crop_bbox = self.get_crop_bbox(img)\n        if self.cat_max_ratio < 1.:\n            # Repeat 10 times\n            for _ in range(10):\n                seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)\n                labels, cnt = np.unique(seg_temp, return_counts=True)\n                cnt = cnt[labels != self.ignore_index]\n                if len(cnt) > 1 and np.max(cnt) / np.sum(\n                        cnt) < self.cat_max_ratio:\n                    break\n                crop_bbox = self.get_crop_bbox(img)\n\n        # crop the image\n        img = self.crop(img, crop_bbox)\n        img_shape = img.shape\n        results['img'] = img\n        results['img_shape'] = img_shape\n\n        # crop semantic seg\n        for key in results.get('seg_fields', []):\n            results[key] = self.crop(results[key], 
crop_bbox)\n\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + f'(crop_size={self.crop_size})'\n\n\n@PIPELINES.register_module()\nclass RandomRotate(object):\n    \"\"\"Rotate the image & seg.\n\n    Args:\n        prob (float): The rotation probability.\n        degree (float, tuple[float]): Range of degrees to select from. If\n            degree is a number instead of tuple like (min, max),\n            the range of degree will be (``-degree``, ``+degree``).\n        pad_val (float, optional): Padding value of image. Default: 0.\n        seg_pad_val (float, optional): Padding value of segmentation map.\n            Default: 255.\n        center (tuple[float], optional): Center point (w, h) of the rotation in\n            the source image. If not specified, the center of the image will be\n            used. Default: None.\n        auto_bound (bool): Whether to adjust the image size to cover the whole\n            rotated image. Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 prob,\n                 degree,\n                 pad_val=0,\n                 seg_pad_val=255,\n                 center=None,\n                 auto_bound=False):\n        self.prob = prob\n        assert prob >= 0 and prob <= 1\n        if isinstance(degree, (float, int)):\n            assert degree > 0, f'degree {degree} should be positive'\n            self.degree = (-degree, degree)\n        else:\n            self.degree = degree\n        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \\\n                                      f'tuple of (min, max)'\n        self.pad_val = pad_val\n        self.seg_pad_val = seg_pad_val\n        self.center = center\n        self.auto_bound = auto_bound\n\n    def __call__(self, results):\n        \"\"\"Call function to rotate image, semantic segmentation maps.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Rotated results.\n        \"\"\"\n\n        rotate = True if np.random.rand() < self.prob else False\n        degree = np.random.uniform(min(*self.degree), max(*self.degree))\n        if rotate:\n            # rotate image\n            results['img'] = mmcv.imrotate(\n                results['img'],\n                angle=degree,\n                border_value=self.pad_val,\n                center=self.center,\n                auto_bound=self.auto_bound)\n\n            # rotate segs\n            for key in results.get('seg_fields', []):\n                results[key] = mmcv.imrotate(\n                    results[key],\n                    angle=degree,\n                    border_value=self.seg_pad_val,\n                    center=self.center,\n                    auto_bound=self.auto_bound,\n                    interpolation='nearest')\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(prob={self.prob}, ' \\\n                    f'degree={self.degree}, ' \\\n                    f'pad_val={self.pad_val}, ' \\\n                    f'seg_pad_val={self.seg_pad_val}, ' \\\n                    f'center={self.center}, ' \\\n                    f'auto_bound={self.auto_bound})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass RGB2Gray(object):\n    \"\"\"Convert RGB image to grayscale image.\n\n    This transform calculates the weighted mean of input image channels with\n    ``weights`` and then expands the channels to ``out_channels``. 
When\n    ``out_channels`` is None, the number of output channels is the same as\n    input channels.\n\n    Args:\n        out_channels (int): Expected number of output channels after\n            transforming. Default: None.\n        weights (tuple[float]): The weights to calculate the weighted mean.\n            Default: (0.299, 0.587, 0.114).\n    \"\"\"\n\n    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):\n        assert out_channels is None or out_channels > 0\n        self.out_channels = out_channels\n        assert isinstance(weights, tuple)\n        for item in weights:\n            assert isinstance(item, (float, int))\n        self.weights = weights\n\n    def __call__(self, results):\n        \"\"\"Call function to convert RGB image to grayscale image.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Result dict with grayscale image.\n        \"\"\"\n        img = results['img']\n        assert len(img.shape) == 3\n        assert img.shape[2] == len(self.weights)\n        weights = np.array(self.weights).reshape((1, 1, -1))\n        img = (img * weights).sum(2, keepdims=True)\n        if self.out_channels is None:\n            img = img.repeat(weights.shape[2], axis=2)\n        else:\n            img = img.repeat(self.out_channels, axis=2)\n\n        results['img'] = img\n        results['img_shape'] = img.shape\n\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(out_channels={self.out_channels}, ' \\\n                    f'weights={self.weights})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass AdjustGamma(object):\n    \"\"\"Using gamma correction to process the image.\n\n    Args:\n        gamma (float or int): Gamma value used in gamma correction.\n            Default: 1.0.\n    \"\"\"\n\n    def __init__(self, gamma=1.0):\n        assert isinstance(gamma, float) or isinstance(gamma, int)\n        assert gamma > 0\n        self.gamma = gamma\n        inv_gamma = 1.0 / gamma\n        self.table = np.array([(i / 255.0)**inv_gamma * 255\n                               for i in np.arange(256)]).astype('uint8')\n\n    def __call__(self, results):\n        \"\"\"Call function to process the image with gamma correction.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Processed results.\n        \"\"\"\n\n        results['img'] = mmcv.lut_transform(\n            np.array(results['img'], dtype=np.uint8), self.table)\n\n        return results\n\n    def __repr__(self):\n        return self.__class__.__name__ + f'(gamma={self.gamma})'\n\n\n@PIPELINES.register_module()\nclass SegRescale(object):\n    \"\"\"Rescale semantic segmentation maps.\n\n    Args:\n        scale_factor (float): The scale factor of the final output.\n    \"\"\"\n\n    def __init__(self, scale_factor=1):\n        self.scale_factor = scale_factor\n\n    def __call__(self, results):\n        \"\"\"Call function to scale the semantic segmentation map.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Result dict with semantic segmentation map scaled.\n        \"\"\"\n        for key in results.get('seg_fields', []):\n            if self.scale_factor != 1:\n                results[key] = mmcv.imrescale(\n                    results[key], self.scale_factor, interpolation='nearest')\n        return 
results\n\n    def __repr__(self):\n        return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'\n\n\n@PIPELINES.register_module()\nclass PhotoMetricDistortion(object):\n    \"\"\"Apply photometric distortion to the image sequentially; every\n    transformation is applied with a probability of 0.5. The random contrast is\n    applied either second or second to last.\n\n    1. random brightness\n    2. random contrast (mode 1)\n    3. convert color from BGR to HSV\n    4. random saturation\n    5. random hue\n    6. convert color from HSV to BGR\n    7. random contrast (mode 0)\n\n    Args:\n        brightness_delta (int): delta of brightness.\n        contrast_range (tuple): range of contrast.\n        saturation_range (tuple): range of saturation.\n        hue_delta (int): delta of hue.\n    \"\"\"\n\n    def __init__(self,\n                 brightness_delta=32,\n                 contrast_range=(0.5, 1.5),\n                 saturation_range=(0.5, 1.5),\n                 hue_delta=18):\n        self.brightness_delta = brightness_delta\n        self.contrast_lower, self.contrast_upper = contrast_range\n        self.saturation_lower, self.saturation_upper = saturation_range\n        self.hue_delta = hue_delta\n\n    def convert(self, img, alpha=1, beta=0):\n        \"\"\"Multiply by alpha, add beta, then clip to [0, 255].\"\"\"\n        img = img.astype(np.float32) * alpha + beta\n        img = np.clip(img, 0, 255)\n        return img.astype(np.uint8)\n\n    def brightness(self, img):\n        \"\"\"Brightness distortion.\"\"\"\n        if random.randint(2):\n            return self.convert(\n                img,\n                beta=random.uniform(-self.brightness_delta,\n                                    self.brightness_delta))\n        return img\n\n    def contrast(self, img):\n        \"\"\"Contrast distortion.\"\"\"\n        if random.randint(2):\n            return self.convert(\n                img,\n                alpha=random.uniform(self.contrast_lower, self.contrast_upper))\n        return img\n\n    def saturation(self, img):\n        \"\"\"Saturation distortion.\"\"\"\n        if random.randint(2):\n            img = mmcv.bgr2hsv(img)\n            img[:, :, 1] = self.convert(\n                img[:, :, 1],\n                alpha=random.uniform(self.saturation_lower,\n                                     self.saturation_upper))\n            img = mmcv.hsv2bgr(img)\n        return img\n\n    def hue(self, img):\n        \"\"\"Hue distortion.\"\"\"\n        if random.randint(2):\n            img = mmcv.bgr2hsv(img)\n            img[:, :,\n                0] = (img[:, :, 0].astype(int) +\n                      random.randint(-self.hue_delta, self.hue_delta)) % 180\n            img = mmcv.hsv2bgr(img)\n        return img\n\n    def __call__(self, results):\n        \"\"\"Call function to perform photometric distortion on images.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Result dict with images distorted.\n        \"\"\"\n\n        img = results['img']\n        # random brightness\n        img = self.brightness(img)\n\n        # mode == 1 --> do random contrast first\n        # mode == 0 --> do random contrast last\n        mode = random.randint(2)\n        if mode == 1:\n            img = self.contrast(img)\n\n        # random saturation\n        img = self.saturation(img)\n\n        # random hue\n        img = self.hue(img)\n\n        # random contrast\n        if mode == 0:\n            img = self.contrast(img)\n\n        results['img'] = img\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += (f'(brightness_delta={self.brightness_delta}, '\n                     f'contrast_range=({self.contrast_lower}, '\n                     f'{self.contrast_upper}), '\n                     f'saturation_range=({self.saturation_lower}, '\n                     f'{self.saturation_upper}), '\n                     f'hue_delta={self.hue_delta})')\n        return repr_str\n"
  },
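  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/transforms_usage_example.py",
    "content": "# Hedged usage sketch, not part of the original repository: it shows how the\n# transform classes defined in this package (RandomCrop, Normalize, Pad) operate on\n# the ``results`` dict that the loading pipeline passes from step to step.  The file\n# name and the import path below are assumptions made for illustration only.\nimport numpy as np\n\nfrom mmseg.datasets.pipelines import Normalize, Pad, RandomCrop\n\nif __name__ == '__main__':\n    # Minimal ``results`` dict with a random image and an empty segmentation map.\n    img = np.random.randint(0, 256, (500, 375, 3), dtype=np.uint8).astype(np.float32)\n    results = dict(\n        img=img,\n        gt_semantic_seg=np.zeros((500, 375), dtype=np.uint8),\n        seg_fields=['gt_semantic_seg'])\n\n    # Crop to a fixed size, normalize with ImageNet statistics, then pad the image\n    # and the segmentation map to a multiple of 32.\n    pipeline = [\n        RandomCrop(crop_size=(480, 360)),\n        Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),\n        Pad(size_divisor=32),\n    ]\n    for transform in pipeline:\n        results = transform(results)\n    print(results['img'].shape, results['pad_shape'])\n"
  },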
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/stare.py",
    "content": "import os.path as osp\n\nfrom .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass STAREDataset(CustomDataset):\n    \"\"\"STARE dataset.\n\n    In segmentation map annotation for STARE, 0 stands for background, which is\n    included in 2 categories. ``reduce_zero_label`` is fixed to False. The\n    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to\n    '.ah.png'.\n    \"\"\"\n\n    CLASSES = ('background', 'vessel')\n\n    PALETTE = [[120, 120, 120], [6, 230, 230]]\n\n    def __init__(self, **kwargs):\n        super(STAREDataset, self).__init__(\n            img_suffix='.png',\n            seg_map_suffix='.ah.png',\n            reduce_zero_label=False,\n            **kwargs)\n        assert osp.exists(self.img_dir)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/datasets/voc.py",
    "content": "import os.path as osp\n\nfrom .builder import DATASETS\nfrom .custom import CustomDataset\n\n\n@DATASETS.register_module()\nclass PascalVOCDataset(CustomDataset):\n    \"\"\"Pascal VOC dataset.\n\n    Args:\n        split (str): Split txt file for Pascal VOC.\n    \"\"\"\n\n    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',\n               'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',\n               'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',\n               'train', 'tvmonitor')\n\n    PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],\n               [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],\n               [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],\n               [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],\n               [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]\n\n    def __init__(self, split, **kwargs):\n        super(PascalVOCDataset, self).__init__(\n            img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)\n        assert osp.exists(self.img_dir) and self.split is not None\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/__init__.py",
    "content": "from .backbones import *  # noqa: F401,F403\nfrom .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,\n                      build_head, build_loss, build_segmentor)\nfrom .decode_heads import *  # noqa: F401,F403\nfrom .losses import *  # noqa: F401,F403\nfrom .necks import *  # noqa: F401,F403\nfrom .segmentors import *  # noqa: F401,F403\n\n__all__ = [\n    'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',\n    'build_head', 'build_loss', 'build_segmentor'\n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/__init__.py",
    "content": "from .cgnet import CGNet\nfrom .fast_scnn import FastSCNN\nfrom .hrnet import HRNet\nfrom .mobilenet_v2 import MobileNetV2\nfrom .mobilenet_v3 import MobileNetV3\nfrom .resnest import ResNeSt\nfrom .resnet import ResNet, ResNetV1c, ResNetV1d\nfrom .resnext import ResNeXt\nfrom .unet import UNet\n\n__all__ = [\n    'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',\n    'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3'\n]\n"
  },
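  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/build_backbone_example.py",
    "content": "# Hedged sketch, not part of the original repository: the package __init__ files\n# above expose the BACKBONES registry together with build_backbone, so a backbone\n# decorated with @BACKBONES.register_module() can be constructed from a config dict.\n# The config-dict convention follows the usual mmcv registry pattern and is an\n# assumption here, as are this file name and the import path.\nfrom mmseg.models import CGNet, build_backbone\n\nif __name__ == '__main__':\n    # 'type' selects the registered class; the remaining keys become constructor\n    # keyword arguments.\n    backbone = build_backbone(dict(type='CGNet', num_blocks=(3, 21)))\n    assert isinstance(backbone, CGNet)\n    print(type(backbone).__name__)\n"
  },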
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/cgnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.utils.checkpoint as cp\nfrom mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,\n                      constant_init, kaiming_init)\nfrom mmcv.runner import load_checkpoint\nfrom mmcv.utils.parrots_wrapper import _BatchNorm\n\nfrom mmseg.utils import get_root_logger\nfrom ..builder import BACKBONES\n\n\nclass GlobalContextExtractor(nn.Module):\n    \"\"\"Global Context Extractor for CGNet.\n\n    This class is employed to refine the joFint feature of both local feature\n    and surrounding context.\n\n    Args:\n        channel (int): Number of input feature channels.\n        reduction (int): Reductions for global context extractor. Default: 16.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n    \"\"\"\n\n    def __init__(self, channel, reduction=16, with_cp=False):\n        super(GlobalContextExtractor, self).__init__()\n        self.channel = channel\n        self.reduction = reduction\n        assert reduction >= 1 and channel >= reduction\n        self.with_cp = with_cp\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel), nn.Sigmoid())\n\n    def forward(self, x):\n\n        def _inner_forward(x):\n            num_batch, num_channel = x.size()[:2]\n            y = self.avg_pool(x).view(num_batch, num_channel)\n            y = self.fc(y).view(num_batch, num_channel, 1, 1)\n            return x * y\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(_inner_forward, x)\n        else:\n            out = _inner_forward(x)\n\n        return out\n\n\nclass ContextGuidedBlock(nn.Module):\n    \"\"\"Context Guided Block for CGNet.\n\n    This class consists of four components: local feature extractor,\n    surrounding feature extractor, joint feature extractor and global\n    context extractor.\n\n    Args:\n        in_channels (int): Number of input feature channels.\n        out_channels (int): Number of output feature channels.\n        dilation (int): Dilation rate for surrounding context extractor.\n            Default: 2.\n        reduction (int): Reduction for global context extractor. Default: 16.\n        skip_connect (bool): Add input to output or not. Default: True.\n        downsample (bool): Downsample the input to 1/2 or not. Default: False.\n        conv_cfg (dict): Config dict for convolution layer.\n            Default: None, which means using conv2d.\n        norm_cfg (dict): Config dict for normalization layer.\n            Default: dict(type='BN', requires_grad=True).\n        act_cfg (dict): Config dict for activation layer.\n            Default: dict(type='PReLU').\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. 
Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 dilation=2,\n                 reduction=16,\n                 skip_connect=True,\n                 downsample=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN', requires_grad=True),\n                 act_cfg=dict(type='PReLU'),\n                 with_cp=False):\n        super(ContextGuidedBlock, self).__init__()\n        self.with_cp = with_cp\n        self.downsample = downsample\n\n        channels = out_channels if downsample else out_channels // 2\n        if 'type' in act_cfg and act_cfg['type'] == 'PReLU':\n            act_cfg['num_parameters'] = channels\n        kernel_size = 3 if downsample else 1\n        stride = 2 if downsample else 1\n        padding = (kernel_size - 1) // 2\n\n        self.conv1x1 = ConvModule(\n            in_channels,\n            channels,\n            kernel_size,\n            stride,\n            padding,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n\n        self.f_loc = build_conv_layer(\n            conv_cfg,\n            channels,\n            channels,\n            kernel_size=3,\n            padding=1,\n            groups=channels,\n            bias=False)\n        self.f_sur = build_conv_layer(\n            conv_cfg,\n            channels,\n            channels,\n            kernel_size=3,\n            padding=dilation,\n            groups=channels,\n            dilation=dilation,\n            bias=False)\n\n        self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]\n        self.activate = nn.PReLU(2 * channels)\n\n        if downsample:\n            self.bottleneck = build_conv_layer(\n                conv_cfg,\n                2 * channels,\n                out_channels,\n                kernel_size=1,\n                bias=False)\n\n        self.skip_connect = skip_connect and not downsample\n        self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)\n\n    def forward(self, x):\n\n        def _inner_forward(x):\n            out = self.conv1x1(x)\n            loc = self.f_loc(out)\n            sur = self.f_sur(out)\n\n            joi_feat = torch.cat([loc, sur], 1)  # the joint feature\n            joi_feat = self.bn(joi_feat)\n            joi_feat = self.activate(joi_feat)\n            if self.downsample:\n                joi_feat = self.bottleneck(joi_feat)  # channel = out_channels\n            # f_glo is employed to refine the joint feature\n            out = self.f_glo(joi_feat)\n\n            if self.skip_connect:\n                return x + out\n            else:\n                return out\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(_inner_forward, x)\n        else:\n            out = _inner_forward(x)\n\n        return out\n\n\nclass InputInjection(nn.Module):\n    \"\"\"Downsampling module for CGNet.\"\"\"\n\n    def __init__(self, num_downsampling):\n        super(InputInjection, self).__init__()\n        self.pool = nn.ModuleList()\n        for i in range(num_downsampling):\n            self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))\n\n    def forward(self, x):\n        for pool in self.pool:\n            x = pool(x)\n        return x\n\n\n@BACKBONES.register_module()\nclass CGNet(nn.Module):\n    \"\"\"CGNet backbone.\n\n    A Light-weight Context Guided Network for Semantic Segmentation\n    arXiv: https://arxiv.org/abs/1811.08201\n\n    Args:\n       
 in_channels (int): Number of input image channels. Normally 3.\n        num_channels (tuple[int]): Numbers of feature channels at each stages.\n            Default: (32, 64, 128).\n        num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.\n            Default: (3, 21).\n        dilations (tuple[int]): Dilation rate for surrounding context\n            extractors at stage 1 and stage 2. Default: (2, 4).\n        reductions (tuple[int]): Reductions for global context extractors at\n            stage 1 and stage 2. Default: (8, 16).\n        conv_cfg (dict): Config dict for convolution layer.\n            Default: None, which means using conv2d.\n        norm_cfg (dict): Config dict for normalization layer.\n            Default: dict(type='BN', requires_grad=True).\n        act_cfg (dict): Config dict for activation layer.\n            Default: dict(type='PReLU').\n        norm_eval (bool): Whether to set norm layers to eval mode, namely,\n            freeze running stats (mean and var). Note: Effect on Batch Norm\n            and its variants only. Default: False.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels=3,\n                 num_channels=(32, 64, 128),\n                 num_blocks=(3, 21),\n                 dilations=(2, 4),\n                 reductions=(8, 16),\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN', requires_grad=True),\n                 act_cfg=dict(type='PReLU'),\n                 norm_eval=False,\n                 with_cp=False):\n\n        super(CGNet, self).__init__()\n        self.in_channels = in_channels\n        self.num_channels = num_channels\n        assert isinstance(self.num_channels, tuple) and len(\n            self.num_channels) == 3\n        self.num_blocks = num_blocks\n        assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2\n        self.dilations = dilations\n        assert isinstance(self.dilations, tuple) and len(self.dilations) == 2\n        self.reductions = reductions\n        assert isinstance(self.reductions, tuple) and len(self.reductions) == 2\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':\n            self.act_cfg['num_parameters'] = num_channels[0]\n        self.norm_eval = norm_eval\n        self.with_cp = with_cp\n\n        cur_channels = in_channels\n        self.stem = nn.ModuleList()\n        for i in range(3):\n            self.stem.append(\n                ConvModule(\n                    cur_channels,\n                    num_channels[0],\n                    3,\n                    2 if i == 0 else 1,\n                    padding=1,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg))\n            cur_channels = num_channels[0]\n\n        self.inject_2x = InputInjection(1)  # down-sample for Input, factor=2\n        self.inject_4x = InputInjection(2)  # down-sample for Input, factor=4\n\n        cur_channels += in_channels\n        self.norm_prelu_0 = nn.Sequential(\n            build_norm_layer(norm_cfg, cur_channels)[1],\n            nn.PReLU(cur_channels))\n\n        # stage 1\n        self.level1 = nn.ModuleList()\n        for i in range(num_blocks[0]):\n            self.level1.append(\n              
  ContextGuidedBlock(\n                    cur_channels if i == 0 else num_channels[1],\n                    num_channels[1],\n                    dilations[0],\n                    reductions[0],\n                    downsample=(i == 0),\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg,\n                    with_cp=with_cp))  # CG block\n\n        cur_channels = 2 * num_channels[1] + in_channels\n        self.norm_prelu_1 = nn.Sequential(\n            build_norm_layer(norm_cfg, cur_channels)[1],\n            nn.PReLU(cur_channels))\n\n        # stage 2\n        self.level2 = nn.ModuleList()\n        for i in range(num_blocks[1]):\n            self.level2.append(\n                ContextGuidedBlock(\n                    cur_channels if i == 0 else num_channels[2],\n                    num_channels[2],\n                    dilations[1],\n                    reductions[1],\n                    downsample=(i == 0),\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg,\n                    with_cp=with_cp))  # CG block\n\n        cur_channels = 2 * num_channels[2]\n        self.norm_prelu_2 = nn.Sequential(\n            build_norm_layer(norm_cfg, cur_channels)[1],\n            nn.PReLU(cur_channels))\n\n    def forward(self, x):\n        output = []\n\n        # stage 0\n        inp_2x = self.inject_2x(x)\n        inp_4x = self.inject_4x(x)\n        for layer in self.stem:\n            x = layer(x)\n        x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))\n        output.append(x)\n\n        # stage 1\n        for i, layer in enumerate(self.level1):\n            x = layer(x)\n            if i == 0:\n                down1 = x\n        x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))\n        output.append(x)\n\n        # stage 2\n        for i, layer in enumerate(self.level2):\n            x = layer(x)\n            if i == 0:\n                down2 = x\n        x = self.norm_prelu_2(torch.cat([down2, x], 1))\n        output.append(x)\n\n        return output\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        if isinstance(pretrained, str):\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            for m in self.modules():\n                if isinstance(m, (nn.Conv2d, nn.Linear)):\n                    kaiming_init(m)\n                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):\n                    constant_init(m, 1)\n                elif isinstance(m, nn.PReLU):\n                    constant_init(m, 0)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keeping the normalization\n        layers frozen.\"\"\"\n        super(CGNet, self).train(mode)\n        if mode and self.norm_eval:\n            for m in self.modules():\n                # trick: eval has an effect on BatchNorm only\n                if isinstance(m, _BatchNorm):\n                    m.eval()\n"
  },
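  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/cgnet_forward_example.py",
    "content": "# Hedged sketch, not part of the original repository: a minimal forward pass through\n# CGNet as defined in cgnet.py.  Its forward() returns three feature maps: the stem\n# output concatenated with the 2x-downsampled input (num_channels[0] + in_channels\n# channels), stage 1 concatenated with its first block and the 4x-downsampled input\n# (2 * num_channels[1] + in_channels channels), and stage 2 concatenated with its\n# first block (2 * num_channels[2] channels).  The file name is an assumption.\nimport torch\n\nfrom mmseg.models.backbones import CGNet\n\nif __name__ == '__main__':\n    model = CGNet(in_channels=3, num_channels=(32, 64, 128))\n    model.init_weights()\n    model.eval()\n    with torch.no_grad():\n        outs = model(torch.rand(1, 3, 64, 64))\n    # With the defaults above this prints 35, 131 and 256 channels at strides\n    # 2, 4 and 8 respectively.\n    for out in outs:\n        print(tuple(out.shape))\n"
  },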
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/fast_scnn.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,\n                      kaiming_init)\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\nfrom mmseg.models.decode_heads.psp_head import PPM\nfrom mmseg.ops import resize\nfrom ..builder import BACKBONES\nfrom ..utils.inverted_residual import InvertedResidual\n\n\nclass LearningToDownsample(nn.Module):\n    \"\"\"Learning to downsample module.\n\n    Args:\n        in_channels (int): Number of input channels.\n        dw_channels (tuple[int]): Number of output channels of the first and\n            the second depthwise conv (dwconv) layers.\n        out_channels (int): Number of output channels of the whole\n            'learning to downsample' module.\n        conv_cfg (dict | None): Config of conv layers. Default: None\n        norm_cfg (dict | None): Config of norm layers. Default:\n            dict(type='BN')\n        act_cfg (dict): Config of activation layers. Default:\n            dict(type='ReLU')\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 dw_channels,\n                 out_channels,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU')):\n        super(LearningToDownsample, self).__init__()\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        dw_channels1 = dw_channels[0]\n        dw_channels2 = dw_channels[1]\n\n        self.conv = ConvModule(\n            in_channels,\n            dw_channels1,\n            3,\n            stride=2,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.dsconv1 = DepthwiseSeparableConvModule(\n            dw_channels1,\n            dw_channels2,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            norm_cfg=self.norm_cfg)\n        self.dsconv2 = DepthwiseSeparableConvModule(\n            dw_channels2,\n            out_channels,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            norm_cfg=self.norm_cfg)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.dsconv1(x)\n        x = self.dsconv2(x)\n        return x\n\n\nclass GlobalFeatureExtractor(nn.Module):\n    \"\"\"Global feature extractor module.\n\n    Args:\n        in_channels (int): Number of input channels of the GFE module.\n            Default: 64\n        block_channels (tuple[int]): Tuple of ints. Each int specifies the\n            number of output channels of each Inverted Residual module.\n            Default: (64, 96, 128)\n        out_channels(int): Number of output channels of the GFE module.\n            Default: 128\n        expand_ratio (int): Adjusts number of channels of the hidden layer\n            in InvertedResidual by this amount.\n            Default: 6\n        num_blocks (tuple[int]): Tuple of ints. Each int specifies the\n            number of times each Inverted Residual module is repeated.\n            The repeated Inverted Residual modules are called a 'group'.\n            Default: (3, 3, 3)\n        strides (tuple[int]): Tuple of ints. Each int specifies\n            the downsampling factor of each 'group'.\n            Default: (2, 2, 1)\n        pool_scales (tuple[int]): Tuple of ints. 
Each int specifies\n            the parameter required in 'global average pooling' within PPM.\n            Default: (1, 2, 3, 6)\n        conv_cfg (dict | None): Config of conv layers. Default: None\n        norm_cfg (dict | None): Config of norm layers. Default:\n            dict(type='BN')\n        act_cfg (dict): Config of activation layers. Default:\n            dict(type='ReLU')\n        align_corners (bool): align_corners argument of F.interpolate.\n            Default: False\n    \"\"\"\n\n    def __init__(self,\n                 in_channels=64,\n                 block_channels=(64, 96, 128),\n                 out_channels=128,\n                 expand_ratio=6,\n                 num_blocks=(3, 3, 3),\n                 strides=(2, 2, 1),\n                 pool_scales=(1, 2, 3, 6),\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 align_corners=False):\n        super(GlobalFeatureExtractor, self).__init__()\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        assert len(block_channels) == len(num_blocks) == 3\n        self.bottleneck1 = self._make_layer(in_channels, block_channels[0],\n                                            num_blocks[0], strides[0],\n                                            expand_ratio)\n        self.bottleneck2 = self._make_layer(block_channels[0],\n                                            block_channels[1], num_blocks[1],\n                                            strides[1], expand_ratio)\n        self.bottleneck3 = self._make_layer(block_channels[1],\n                                            block_channels[2], num_blocks[2],\n                                            strides[2], expand_ratio)\n        self.ppm = PPM(\n            pool_scales,\n            block_channels[2],\n            block_channels[2] // 4,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg,\n            align_corners=align_corners)\n        self.out = ConvModule(\n            block_channels[2] * 2,\n            out_channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def _make_layer(self,\n                    in_channels,\n                    out_channels,\n                    blocks,\n                    stride=1,\n                    expand_ratio=6):\n        layers = [\n            InvertedResidual(\n                in_channels,\n                out_channels,\n                stride,\n                expand_ratio,\n                norm_cfg=self.norm_cfg)\n        ]\n        for i in range(1, blocks):\n            layers.append(\n                InvertedResidual(\n                    out_channels,\n                    out_channels,\n                    1,\n                    expand_ratio,\n                    norm_cfg=self.norm_cfg))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.bottleneck1(x)\n        x = self.bottleneck2(x)\n        x = self.bottleneck3(x)\n        x = torch.cat([x, *self.ppm(x)], dim=1)\n        x = self.out(x)\n        return x\n\n\nclass FeatureFusionModule(nn.Module):\n    \"\"\"Feature fusion module.\n\n    Args:\n        higher_in_channels (int): Number of input channels of the\n            higher-resolution branch.\n        lower_in_channels (int): Number of input channels of the\n            lower-resolution branch.\n 
       out_channels (int): Number of output channels.\n        conv_cfg (dict | None): Config of conv layers. Default: None\n        norm_cfg (dict | None): Config of norm layers. Default:\n            dict(type='BN')\n        act_cfg (dict): Config of activation layers. Default:\n            dict(type='ReLU')\n        align_corners (bool): align_corners argument of F.interpolate.\n            Default: False\n    \"\"\"\n\n    def __init__(self,\n                 higher_in_channels,\n                 lower_in_channels,\n                 out_channels,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 align_corners=False):\n        super(FeatureFusionModule, self).__init__()\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.align_corners = align_corners\n        self.dwconv = ConvModule(\n            lower_in_channels,\n            out_channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.conv_lower_res = ConvModule(\n            out_channels,\n            out_channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=None)\n        self.conv_higher_res = ConvModule(\n            higher_in_channels,\n            out_channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=None)\n        self.relu = nn.ReLU(True)\n\n    def forward(self, higher_res_feature, lower_res_feature):\n        lower_res_feature = resize(\n            lower_res_feature,\n            size=higher_res_feature.size()[2:],\n            mode='bilinear',\n            align_corners=self.align_corners)\n        lower_res_feature = self.dwconv(lower_res_feature)\n        lower_res_feature = self.conv_lower_res(lower_res_feature)\n\n        higher_res_feature = self.conv_higher_res(higher_res_feature)\n        out = higher_res_feature + lower_res_feature\n        return self.relu(out)\n\n\n@BACKBONES.register_module()\nclass FastSCNN(nn.Module):\n    \"\"\"Fast-SCNN Backbone.\n\n    Args:\n        in_channels (int): Number of input image channels. 
Default: 3.\n        downsample_dw_channels (tuple[int]): Number of output channels after\n            the first conv layer & the second conv layer in\n            Learning-To-Downsample (LTD) module.\n            Default: (32, 48).\n        global_in_channels (int): Number of input channels of\n            Global Feature Extractor(GFE).\n            Equal to number of output channels of LTD.\n            Default: 64.\n        global_block_channels (tuple[int]): Tuple of integers that describe\n            the output channels for each of the MobileNet-v2 bottleneck\n            residual blocks in GFE.\n            Default: (64, 96, 128).\n        global_block_strides (tuple[int]): Tuple of integers\n            that describe the strides (downsampling factors) for each of the\n            MobileNet-v2 bottleneck residual blocks in GFE.\n            Default: (2, 2, 1).\n        global_out_channels (int): Number of output channels of GFE.\n            Default: 128.\n        higher_in_channels (int): Number of input channels of the higher\n            resolution branch in FFM.\n            Equal to global_in_channels.\n            Default: 64.\n        lower_in_channels (int): Number of input channels of  the lower\n            resolution branch in FFM.\n            Equal to global_out_channels.\n            Default: 128.\n        fusion_out_channels (int): Number of output channels of FFM.\n            Default: 128.\n        out_indices (tuple): Tuple of indices of list\n            [higher_res_features, lower_res_features, fusion_output].\n            Often set to (0,1,2) to enable aux. heads.\n            Default: (0, 1, 2).\n        conv_cfg (dict | None): Config of conv layers. Default: None\n        norm_cfg (dict | None): Config of norm layers. Default:\n            dict(type='BN')\n        act_cfg (dict): Config of activation layers. 
Default:\n            dict(type='ReLU')\n        align_corners (bool): align_corners argument of F.interpolate.\n            Default: False\n    \"\"\"\n\n    def __init__(self,\n                 in_channels=3,\n                 downsample_dw_channels=(32, 48),\n                 global_in_channels=64,\n                 global_block_channels=(64, 96, 128),\n                 global_block_strides=(2, 2, 1),\n                 global_out_channels=128,\n                 higher_in_channels=64,\n                 lower_in_channels=128,\n                 fusion_out_channels=128,\n                 out_indices=(0, 1, 2),\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 align_corners=False):\n\n        super(FastSCNN, self).__init__()\n        if global_in_channels != higher_in_channels:\n            raise AssertionError('Global Input Channels must be the same \\\n                                 with Higher Input Channels!')\n        elif global_out_channels != lower_in_channels:\n            raise AssertionError('Global Output Channels must be the same \\\n                                with Lower Input Channels!')\n\n        self.in_channels = in_channels\n        self.downsample_dw_channels1 = downsample_dw_channels[0]\n        self.downsample_dw_channels2 = downsample_dw_channels[1]\n        self.global_in_channels = global_in_channels\n        self.global_block_channels = global_block_channels\n        self.global_block_strides = global_block_strides\n        self.global_out_channels = global_out_channels\n        self.higher_in_channels = higher_in_channels\n        self.lower_in_channels = lower_in_channels\n        self.fusion_out_channels = fusion_out_channels\n        self.out_indices = out_indices\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.align_corners = align_corners\n        self.learning_to_downsample = LearningToDownsample(\n            in_channels,\n            downsample_dw_channels,\n            global_in_channels,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.global_feature_extractor = GlobalFeatureExtractor(\n            global_in_channels,\n            global_block_channels,\n            global_out_channels,\n            strides=self.global_block_strides,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg,\n            align_corners=self.align_corners)\n        self.feature_fusion = FeatureFusionModule(\n            higher_in_channels,\n            lower_in_channels,\n            fusion_out_channels,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg,\n            align_corners=self.align_corners)\n\n    def init_weights(self, pretrained=None):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                kaiming_init(m)\n            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):\n                constant_init(m, 1)\n\n    def forward(self, x):\n        higher_res_features = self.learning_to_downsample(x)\n        lower_res_features = self.global_feature_extractor(higher_res_features)\n        fusion_output = self.feature_fusion(higher_res_features,\n                                            lower_res_features)\n\n        outs = [higher_res_features, lower_res_features, fusion_output]\n        outs 
= [outs[i] for i in self.out_indices]\n        return tuple(outs)\n"
  },
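  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/fast_scnn_forward_example.py",
    "content": "# Hedged sketch, not part of the original repository: a minimal forward pass through\n# FastSCNN as defined in fast_scnn.py.  With the default out_indices=(0, 1, 2) the\n# backbone returns the higher-resolution features from the learning-to-downsample\n# module, the lower-resolution features from the global feature extractor, and the\n# fused output of the feature fusion module.  The file name is an assumption.\nimport torch\n\nfrom mmseg.models.backbones import FastSCNN\n\nif __name__ == '__main__':\n    model = FastSCNN(in_channels=3)\n    model.init_weights()\n    model.eval()\n    with torch.no_grad():\n        higher_res, lower_res, fusion = model(torch.rand(1, 3, 128, 128))\n    # Default widths: 64 channels for the higher-resolution branch and 128 channels\n    # for both the lower-resolution branch and the fused output.\n    print(tuple(higher_res.shape), tuple(lower_res.shape), tuple(fusion.shape))\n"
  },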
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/hrnet.py",
    "content": "import torch.nn as nn\nfrom mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,\n                      kaiming_init)\nfrom mmcv.runner import load_checkpoint\nfrom mmcv.utils.parrots_wrapper import _BatchNorm\n\nfrom mmseg.ops import Upsample, resize\nfrom mmseg.utils import get_root_logger\nfrom ..builder import BACKBONES\nfrom .resnet import BasicBlock, Bottleneck\n\n\nclass HRModule(nn.Module):\n    \"\"\"High-Resolution Module for HRNet.\n\n    In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange\n    is in this module.\n    \"\"\"\n\n    def __init__(self,\n                 num_branches,\n                 blocks,\n                 num_blocks,\n                 in_channels,\n                 num_channels,\n                 multiscale_output=True,\n                 with_cp=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN', requires_grad=True)):\n        super(HRModule, self).__init__()\n        self._check_branches(num_branches, num_blocks, in_channels,\n                             num_channels)\n\n        self.in_channels = in_channels\n        self.num_branches = num_branches\n\n        self.multiscale_output = multiscale_output\n        self.norm_cfg = norm_cfg\n        self.conv_cfg = conv_cfg\n        self.with_cp = with_cp\n        self.branches = self._make_branches(num_branches, blocks, num_blocks,\n                                            num_channels)\n        self.fuse_layers = self._make_fuse_layers()\n        self.relu = nn.ReLU(inplace=False)\n\n    def _check_branches(self, num_branches, num_blocks, in_channels,\n                        num_channels):\n        \"\"\"Check branches configuration.\"\"\"\n        if num_branches != len(num_blocks):\n            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \\\n                        f'{len(num_blocks)})'\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_channels):\n            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \\\n                        f'{len(num_channels)})'\n            raise ValueError(error_msg)\n\n        if num_branches != len(in_channels):\n            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \\\n                        f'{len(in_channels)})'\n            raise ValueError(error_msg)\n\n    def _make_one_branch(self,\n                         branch_index,\n                         block,\n                         num_blocks,\n                         num_channels,\n                         stride=1):\n        \"\"\"Build one branch.\"\"\"\n        downsample = None\n        if stride != 1 or \\\n                self.in_channels[branch_index] != \\\n                num_channels[branch_index] * block.expansion:\n            downsample = nn.Sequential(\n                build_conv_layer(\n                    self.conv_cfg,\n                    self.in_channels[branch_index],\n                    num_channels[branch_index] * block.expansion,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False),\n                build_norm_layer(self.norm_cfg, num_channels[branch_index] *\n                                 block.expansion)[1])\n\n        layers = []\n        layers.append(\n            block(\n                self.in_channels[branch_index],\n                num_channels[branch_index],\n                stride,\n                downsample=downsample,\n                with_cp=self.with_cp,\n     
           norm_cfg=self.norm_cfg,\n                conv_cfg=self.conv_cfg))\n        self.in_channels[branch_index] = \\\n            num_channels[branch_index] * block.expansion\n        for i in range(1, num_blocks[branch_index]):\n            layers.append(\n                block(\n                    self.in_channels[branch_index],\n                    num_channels[branch_index],\n                    with_cp=self.with_cp,\n                    norm_cfg=self.norm_cfg,\n                    conv_cfg=self.conv_cfg))\n\n        return nn.Sequential(*layers)\n\n    def _make_branches(self, num_branches, block, num_blocks, num_channels):\n        \"\"\"Build multiple branch.\"\"\"\n        branches = []\n\n        for i in range(num_branches):\n            branches.append(\n                self._make_one_branch(i, block, num_blocks, num_channels))\n\n        return nn.ModuleList(branches)\n\n    def _make_fuse_layers(self):\n        \"\"\"Build fuse layer.\"\"\"\n        if self.num_branches == 1:\n            return None\n\n        num_branches = self.num_branches\n        in_channels = self.in_channels\n        fuse_layers = []\n        num_out_branches = num_branches if self.multiscale_output else 1\n        for i in range(num_out_branches):\n            fuse_layer = []\n            for j in range(num_branches):\n                if j > i:\n                    fuse_layer.append(\n                        nn.Sequential(\n                            build_conv_layer(\n                                self.conv_cfg,\n                                in_channels[j],\n                                in_channels[i],\n                                kernel_size=1,\n                                stride=1,\n                                padding=0,\n                                bias=False),\n                            build_norm_layer(self.norm_cfg, in_channels[i])[1],\n                            # we set align_corners=False for HRNet\n                            Upsample(\n                                scale_factor=2**(j - i),\n                                mode='bilinear',\n                                align_corners=False)))\n                elif j == i:\n                    fuse_layer.append(None)\n                else:\n                    conv_downsamples = []\n                    for k in range(i - j):\n                        if k == i - j - 1:\n                            conv_downsamples.append(\n                                nn.Sequential(\n                                    build_conv_layer(\n                                        self.conv_cfg,\n                                        in_channels[j],\n                                        in_channels[i],\n                                        kernel_size=3,\n                                        stride=2,\n                                        padding=1,\n                                        bias=False),\n                                    build_norm_layer(self.norm_cfg,\n                                                     in_channels[i])[1]))\n                        else:\n                            conv_downsamples.append(\n                                nn.Sequential(\n                                    build_conv_layer(\n                                        self.conv_cfg,\n                                        in_channels[j],\n                                        in_channels[j],\n                                        kernel_size=3,\n                                        stride=2,\n         
                               padding=1,\n                                        bias=False),\n                                    build_norm_layer(self.norm_cfg,\n                                                     in_channels[j])[1],\n                                    nn.ReLU(inplace=False)))\n                    fuse_layer.append(nn.Sequential(*conv_downsamples))\n            fuse_layers.append(nn.ModuleList(fuse_layer))\n\n        return nn.ModuleList(fuse_layers)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        if self.num_branches == 1:\n            return [self.branches[0](x[0])]\n\n        for i in range(self.num_branches):\n            x[i] = self.branches[i](x[i])\n\n        x_fuse = []\n        for i in range(len(self.fuse_layers)):\n            y = 0\n            for j in range(self.num_branches):\n                if i == j:\n                    y += x[j]\n                elif j > i:\n                    y = y + resize(\n                        self.fuse_layers[i][j](x[j]),\n                        size=x[i].shape[2:],\n                        mode='bilinear',\n                        align_corners=False)\n                else:\n                    y += self.fuse_layers[i][j](x[j])\n            x_fuse.append(self.relu(y))\n        return x_fuse\n\n\n@BACKBONES.register_module()\nclass HRNet(nn.Module):\n    \"\"\"HRNet backbone.\n\n    High-Resolution Representations for Labeling Pixels and Regions\n    arXiv: https://arxiv.org/abs/1904.04514\n\n    Args:\n        extra (dict): detailed configuration for each stage of HRNet.\n        in_channels (int): Number of input image channels. Normally 3.\n        conv_cfg (dict): dictionary to construct and config conv layer.\n        norm_cfg (dict): dictionary to construct and config norm layer.\n        norm_eval (bool): Whether to set norm layers to eval mode, namely,\n            freeze running stats (mean and var). Note: Effect on Batch Norm\n            and its variants only.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed.\n        zero_init_residual (bool): whether to use zero init for last norm layer\n            in resblocks to let them behave as identity.\n\n    Example:\n        >>> from mmseg.models import HRNet\n        >>> import torch\n        >>> extra = dict(\n        >>>     stage1=dict(\n        >>>         num_modules=1,\n        >>>         num_branches=1,\n        >>>         block='BOTTLENECK',\n        >>>         num_blocks=(4, ),\n        >>>         num_channels=(64, )),\n        >>>     stage2=dict(\n        >>>         num_modules=1,\n        >>>         num_branches=2,\n        >>>         block='BASIC',\n        >>>         num_blocks=(4, 4),\n        >>>         num_channels=(32, 64)),\n        >>>     stage3=dict(\n        >>>         num_modules=4,\n        >>>         num_branches=3,\n        >>>         block='BASIC',\n        >>>         num_blocks=(4, 4, 4),\n        >>>         num_channels=(32, 64, 128)),\n        >>>     stage4=dict(\n        >>>         num_modules=3,\n        >>>         num_branches=4,\n        >>>         block='BASIC',\n        >>>         num_blocks=(4, 4, 4, 4),\n        >>>         num_channels=(32, 64, 128, 256)))\n        >>> self = HRNet(extra, in_channels=1)\n        >>> self.eval()\n        >>> inputs = torch.rand(1, 1, 32, 32)\n        >>> level_outputs = self.forward(inputs)\n        >>> for level_out in level_outputs:\n        ...     
print(tuple(level_out.shape))\n        (1, 32, 8, 8)\n        (1, 64, 4, 4)\n        (1, 128, 2, 2)\n        (1, 256, 1, 1)\n    \"\"\"\n\n    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}\n\n    def __init__(self,\n                 extra,\n                 in_channels=3,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN', requires_grad=True),\n                 norm_eval=False,\n                 with_cp=False,\n                 zero_init_residual=False):\n        super(HRNet, self).__init__()\n        self.extra = extra\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.norm_eval = norm_eval\n        self.with_cp = with_cp\n        self.zero_init_residual = zero_init_residual\n\n        # stem net\n        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)\n        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)\n\n        self.conv1 = build_conv_layer(\n            self.conv_cfg,\n            in_channels,\n            64,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            bias=False)\n\n        self.add_module(self.norm1_name, norm1)\n        self.conv2 = build_conv_layer(\n            self.conv_cfg,\n            64,\n            64,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            bias=False)\n\n        self.add_module(self.norm2_name, norm2)\n        self.relu = nn.ReLU(inplace=True)\n\n        # stage 1\n        self.stage1_cfg = self.extra['stage1']\n        num_channels = self.stage1_cfg['num_channels'][0]\n        block_type = self.stage1_cfg['block']\n        num_blocks = self.stage1_cfg['num_blocks'][0]\n\n        block = self.blocks_dict[block_type]\n        stage1_out_channels = num_channels * block.expansion\n        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)\n\n        # stage 2\n        self.stage2_cfg = self.extra['stage2']\n        num_channels = self.stage2_cfg['num_channels']\n        block_type = self.stage2_cfg['block']\n\n        block = self.blocks_dict[block_type]\n        num_channels = [channel * block.expansion for channel in num_channels]\n        self.transition1 = self._make_transition_layer([stage1_out_channels],\n                                                       num_channels)\n        self.stage2, pre_stage_channels = self._make_stage(\n            self.stage2_cfg, num_channels)\n\n        # stage 3\n        self.stage3_cfg = self.extra['stage3']\n        num_channels = self.stage3_cfg['num_channels']\n        block_type = self.stage3_cfg['block']\n\n        block = self.blocks_dict[block_type]\n        num_channels = [channel * block.expansion for channel in num_channels]\n        self.transition2 = self._make_transition_layer(pre_stage_channels,\n                                                       num_channels)\n        self.stage3, pre_stage_channels = self._make_stage(\n            self.stage3_cfg, num_channels)\n\n        # stage 4\n        self.stage4_cfg = self.extra['stage4']\n        num_channels = self.stage4_cfg['num_channels']\n        block_type = self.stage4_cfg['block']\n\n        block = self.blocks_dict[block_type]\n        num_channels = [channel * block.expansion for channel in num_channels]\n        self.transition3 = self._make_transition_layer(pre_stage_channels,\n                                                       num_channels)\n        self.stage4, pre_stage_channels = self._make_stage(\n            
self.stage4_cfg, num_channels)\n\n    @property\n    def norm1(self):\n        \"\"\"nn.Module: the normalization layer named \"norm1\" \"\"\"\n        return getattr(self, self.norm1_name)\n\n    @property\n    def norm2(self):\n        \"\"\"nn.Module: the normalization layer named \"norm2\" \"\"\"\n        return getattr(self, self.norm2_name)\n\n    def _make_transition_layer(self, num_channels_pre_layer,\n                               num_channels_cur_layer):\n        \"\"\"Make transition layer.\"\"\"\n        num_branches_cur = len(num_channels_cur_layer)\n        num_branches_pre = len(num_channels_pre_layer)\n\n        transition_layers = []\n        for i in range(num_branches_cur):\n            if i < num_branches_pre:\n                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:\n                    transition_layers.append(\n                        nn.Sequential(\n                            build_conv_layer(\n                                self.conv_cfg,\n                                num_channels_pre_layer[i],\n                                num_channels_cur_layer[i],\n                                kernel_size=3,\n                                stride=1,\n                                padding=1,\n                                bias=False),\n                            build_norm_layer(self.norm_cfg,\n                                             num_channels_cur_layer[i])[1],\n                            nn.ReLU(inplace=True)))\n                else:\n                    transition_layers.append(None)\n            else:\n                conv_downsamples = []\n                for j in range(i + 1 - num_branches_pre):\n                    in_channels = num_channels_pre_layer[-1]\n                    out_channels = num_channels_cur_layer[i] \\\n                        if j == i - num_branches_pre else in_channels\n                    conv_downsamples.append(\n                        nn.Sequential(\n                            build_conv_layer(\n                                self.conv_cfg,\n                                in_channels,\n                                out_channels,\n                                kernel_size=3,\n                                stride=2,\n                                padding=1,\n                                bias=False),\n                            build_norm_layer(self.norm_cfg, out_channels)[1],\n                            nn.ReLU(inplace=True)))\n                transition_layers.append(nn.Sequential(*conv_downsamples))\n\n        return nn.ModuleList(transition_layers)\n\n    def _make_layer(self, block, inplanes, planes, blocks, stride=1):\n        \"\"\"Make each layer.\"\"\"\n        downsample = None\n        if stride != 1 or inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                build_conv_layer(\n                    self.conv_cfg,\n                    inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False),\n                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])\n\n        layers = []\n        layers.append(\n            block(\n                inplanes,\n                planes,\n                stride,\n                downsample=downsample,\n                with_cp=self.with_cp,\n                norm_cfg=self.norm_cfg,\n                conv_cfg=self.conv_cfg))\n        inplanes = planes * block.expansion\n        for i in range(1, 
blocks):\n            layers.append(\n                block(\n                    inplanes,\n                    planes,\n                    with_cp=self.with_cp,\n                    norm_cfg=self.norm_cfg,\n                    conv_cfg=self.conv_cfg))\n\n        return nn.Sequential(*layers)\n\n    def _make_stage(self, layer_config, in_channels, multiscale_output=True):\n        \"\"\"Make each stage.\"\"\"\n        num_modules = layer_config['num_modules']\n        num_branches = layer_config['num_branches']\n        num_blocks = layer_config['num_blocks']\n        num_channels = layer_config['num_channels']\n        block = self.blocks_dict[layer_config['block']]\n\n        hr_modules = []\n        for i in range(num_modules):\n            # multi_scale_output is only used for the last module\n            if not multiscale_output and i == num_modules - 1:\n                reset_multiscale_output = False\n            else:\n                reset_multiscale_output = True\n\n            hr_modules.append(\n                HRModule(\n                    num_branches,\n                    block,\n                    num_blocks,\n                    in_channels,\n                    num_channels,\n                    reset_multiscale_output,\n                    with_cp=self.with_cp,\n                    norm_cfg=self.norm_cfg,\n                    conv_cfg=self.conv_cfg))\n\n        return nn.Sequential(*hr_modules), in_channels\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        if isinstance(pretrained, str):\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d):\n                    kaiming_init(m)\n                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):\n                    constant_init(m, 1)\n\n            if self.zero_init_residual:\n                for m in self.modules():\n                    if isinstance(m, Bottleneck):\n                        constant_init(m.norm3, 0)\n                    elif isinstance(m, BasicBlock):\n                        constant_init(m.norm2, 0)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu(x)\n        x = self.conv2(x)\n        x = self.norm2(x)\n        x = self.relu(x)\n        x = self.layer1(x)\n\n        x_list = []\n        for i in range(self.stage2_cfg['num_branches']):\n            if self.transition1[i] is not None:\n                x_list.append(self.transition1[i](x))\n            else:\n                x_list.append(x)\n        y_list = self.stage2(x_list)\n\n        x_list = []\n        for i in range(self.stage3_cfg['num_branches']):\n            if self.transition2[i] is not None:\n                x_list.append(self.transition2[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        y_list = self.stage3(x_list)\n\n        x_list = []\n        for i in range(self.stage4_cfg['num_branches']):\n            if self.transition3[i] is not None:\n                x_list.append(self.transition3[i](y_list[-1]))\n            else:\n                
x_list.append(y_list[i])\n        y_list = self.stage4(x_list)\n\n        return y_list\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keeping the normalization\n        layers frozen.\"\"\"\n        super(HRNet, self).train(mode)\n        if mode and self.norm_eval:\n            for m in self.modules():\n                # trick: eval() has an effect on BatchNorm only\n                if isinstance(m, _BatchNorm):\n                    m.eval()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/mobilenet_v2.py",
    "content": "import logging\n\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule, constant_init, kaiming_init\nfrom mmcv.runner import load_checkpoint\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\nfrom ..builder import BACKBONES\nfrom ..utils import InvertedResidual, make_divisible\n\n\n@BACKBONES.register_module()\nclass MobileNetV2(nn.Module):\n    \"\"\"MobileNetV2 backbone.\n\n    Args:\n        widen_factor (float): Width multiplier, multiply number of\n            channels in each layer by this amount. Default: 1.0.\n        strides (Sequence[int], optional): Strides of the first block of each\n            layer. If not specified, default config in ``arch_setting`` will\n            be used.\n        dilations (Sequence[int]): Dilation of each layer.\n        out_indices (None or Sequence[int]): Output from which stages.\n            Default: (7, ).\n        frozen_stages (int): Stages to be frozen (all param fixed).\n            Default: -1, which means not freezing any parameters.\n        conv_cfg (dict): Config dict for convolution layer.\n            Default: None, which means using conv2d.\n        norm_cfg (dict): Config dict for normalization layer.\n            Default: dict(type='BN').\n        act_cfg (dict): Config dict for activation layer.\n            Default: dict(type='ReLU6').\n        norm_eval (bool): Whether to set norm layers to eval mode, namely,\n            freeze running stats (mean and var). Note: Effect on Batch Norm\n            and its variants only. Default: False.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n    \"\"\"\n\n    # Parameters to build layers. 3 parameters are needed to construct a\n    # layer, from left to right: expand_ratio, channel, num_blocks.\n    arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],\n                     [6, 96, 3], [6, 160, 3], [6, 320, 1]]\n\n    def __init__(self,\n                 widen_factor=1.,\n                 strides=(1, 2, 2, 2, 1, 2, 1),\n                 dilations=(1, 1, 1, 1, 1, 1, 1),\n                 out_indices=(1, 2, 4, 6),\n                 frozen_stages=-1,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU6'),\n                 norm_eval=False,\n                 with_cp=False):\n        super(MobileNetV2, self).__init__()\n        self.widen_factor = widen_factor\n        self.strides = strides\n        self.dilations = dilations\n        assert len(strides) == len(dilations) == len(self.arch_settings)\n        self.out_indices = out_indices\n        for index in out_indices:\n            if index not in range(0, 7):\n                raise ValueError('the item in out_indices must in '\n                                 f'range(0, 8). But received {index}')\n\n        if frozen_stages not in range(-1, 7):\n            raise ValueError('frozen_stages must be in range(-1, 7). 
'\n                             f'But received {frozen_stages}')\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.norm_eval = norm_eval\n        self.with_cp = with_cp\n\n        self.in_channels = make_divisible(32 * widen_factor, 8)\n\n        self.conv1 = ConvModule(\n            in_channels=3,\n            out_channels=self.in_channels,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n        self.layers = []\n\n        for i, layer_cfg in enumerate(self.arch_settings):\n            expand_ratio, channel, num_blocks = layer_cfg\n            stride = self.strides[i]\n            dilation = self.dilations[i]\n            out_channels = make_divisible(channel * widen_factor, 8)\n            inverted_res_layer = self.make_layer(\n                out_channels=out_channels,\n                num_blocks=num_blocks,\n                stride=stride,\n                dilation=dilation,\n                expand_ratio=expand_ratio)\n            layer_name = f'layer{i + 1}'\n            self.add_module(layer_name, inverted_res_layer)\n            self.layers.append(layer_name)\n\n    def make_layer(self, out_channels, num_blocks, stride, dilation,\n                   expand_ratio):\n        \"\"\"Stack InvertedResidual blocks to build a layer for MobileNetV2.\n\n        Args:\n            out_channels (int): out_channels of block.\n            num_blocks (int): Number of blocks.\n            stride (int): Stride of the first block.\n            dilation (int): Dilation of the first block.\n            expand_ratio (int): Expand the number of channels of the\n                hidden layer in InvertedResidual by this ratio.\n        \"\"\"\n        layers = []\n        for i in range(num_blocks):\n            layers.append(\n                InvertedResidual(\n                    self.in_channels,\n                    out_channels,\n                    stride if i == 0 else 1,\n                    expand_ratio=expand_ratio,\n                    dilation=dilation if i == 0 else 1,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg,\n                    with_cp=self.with_cp))\n            self.in_channels = out_channels\n\n        return nn.Sequential(*layers)\n\n    def init_weights(self, pretrained=None):\n        if isinstance(pretrained, str):\n            logger = logging.getLogger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d):\n                    kaiming_init(m)\n                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):\n                    constant_init(m, 1)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def forward(self, x):\n        x = self.conv1(x)\n\n        outs = []\n        for i, layer_name in enumerate(self.layers):\n            layer = getattr(self, layer_name)\n            x = layer(x)\n            if i in self.out_indices:\n                outs.append(x)\n\n        if len(outs) == 1:\n            return outs[0]\n        else:\n            return tuple(outs)\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n       
     for param in self.conv1.parameters():\n                param.requires_grad = False\n        for i in range(1, self.frozen_stages + 1):\n            layer = getattr(self, f'layer{i}')\n            layer.eval()\n            for param in layer.parameters():\n                param.requires_grad = False\n\n    def train(self, mode=True):\n        super(MobileNetV2, self).train(mode)\n        self._freeze_stages()\n        if mode and self.norm_eval:\n            for m in self.modules():\n                if isinstance(m, _BatchNorm):\n                    m.eval()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/mobilenet_v3.py",
    "content": "import logging\n\nimport mmcv\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule, constant_init, kaiming_init\nfrom mmcv.cnn.bricks import Conv2dAdaptivePadding\nfrom mmcv.runner import load_checkpoint\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\nfrom ..builder import BACKBONES\nfrom ..utils import InvertedResidualV3 as InvertedResidual\n\n\n@BACKBONES.register_module()\nclass MobileNetV3(nn.Module):\n    \"\"\"MobileNetV3 backbone.\n\n    This backbone is the improved implementation of `Searching for MobileNetV3\n    <https://ieeexplore.ieee.org/document/9008835>`_.\n\n    Args:\n        arch (str): Architechture of mobilnetv3, from {'small', 'large'}.\n            Default: 'small'.\n        conv_cfg (dict): Config dict for convolution layer.\n            Default: None, which means using conv2d.\n        norm_cfg (dict): Config dict for normalization layer.\n            Default: dict(type='BN').\n        out_indices (tuple[int]): Output from which layer.\n            Default: (0, 1, 12).\n        frozen_stages (int): Stages to be frozen (all param fixed).\n            Defualt: -1, which means not freezing any parameters.\n        norm_eval (bool): Whether to set norm layers to eval mode, namely,\n            freeze running stats (mean and var). Note: Effect on Batch Norm\n            and its variants only. Default: False.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save\n            some memory while slowing down the training speed.\n            Defualt: False.\n    \"\"\"\n    # Parameters to build each block:\n    #     [kernel size, mid channels, out channels, with_se, act type, stride]\n    arch_settings = {\n        'small': [[3, 16, 16, True, 'ReLU', 2],  # block0 layer1 os=4\n                  [3, 72, 24, False, 'ReLU', 2],  # block1 layer2 os=8\n                  [3, 88, 24, False, 'ReLU', 1],\n                  [5, 96, 40, True, 'HSwish', 2],  # block2 layer4 os=16\n                  [5, 240, 40, True, 'HSwish', 1],\n                  [5, 240, 40, True, 'HSwish', 1],\n                  [5, 120, 48, True, 'HSwish', 1],  # block3 layer7 os=16\n                  [5, 144, 48, True, 'HSwish', 1],\n                  [5, 288, 96, True, 'HSwish', 2],  # block4 layer9 os=32\n                  [5, 576, 96, True, 'HSwish', 1],\n                  [5, 576, 96, True, 'HSwish', 1]],\n        'large': [[3, 16, 16, False, 'ReLU', 1],  # block0 layer1 os=2\n                  [3, 64, 24, False, 'ReLU', 2],  # block1 layer2 os=4\n                  [3, 72, 24, False, 'ReLU', 1],\n                  [5, 72, 40, True, 'ReLU', 2],  # block2 layer4 os=8\n                  [5, 120, 40, True, 'ReLU', 1],\n                  [5, 120, 40, True, 'ReLU', 1],\n                  [3, 240, 80, False, 'HSwish', 2],  # block3 layer7 os=16\n                  [3, 200, 80, False, 'HSwish', 1],\n                  [3, 184, 80, False, 'HSwish', 1],\n                  [3, 184, 80, False, 'HSwish', 1],\n                  [3, 480, 112, True, 'HSwish', 1],  # block4 layer11 os=16\n                  [3, 672, 112, True, 'HSwish', 1],\n                  [5, 672, 160, True, 'HSwish', 2],  # block5 layer13 os=32\n                  [5, 960, 160, True, 'HSwish', 1],\n                  [5, 960, 160, True, 'HSwish', 1]]\n    }  # yapf: disable\n\n    def __init__(self,\n                 arch='small',\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 out_indices=(0, 1, 12),\n                 frozen_stages=-1,\n                 
reduction_factor=1,\n                 norm_eval=False,\n                 with_cp=False):\n        super(MobileNetV3, self).__init__()\n        assert arch in self.arch_settings\n        assert isinstance(reduction_factor, int) and reduction_factor > 0\n        assert mmcv.is_tuple_of(out_indices, int)\n        for index in out_indices:\n            if index not in range(0, len(self.arch_settings[arch]) + 2):\n                raise ValueError(\n                    'the item in out_indices must in '\n                    f'range(0, {len(self.arch_settings[arch])+2}). '\n                    f'But received {index}')\n\n        if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):\n            raise ValueError('frozen_stages must be in range(-1, '\n                             f'{len(self.arch_settings[arch])+2}). '\n                             f'But received {frozen_stages}')\n        self.arch = arch\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n        self.reduction_factor = reduction_factor\n        self.norm_eval = norm_eval\n        self.with_cp = with_cp\n        self.layers = self._make_layer()\n\n    def _make_layer(self):\n        layers = []\n\n        # build the first layer (layer0)\n        in_channels = 16\n        layer = ConvModule(\n            in_channels=3,\n            out_channels=in_channels,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            conv_cfg=dict(type='Conv2dAdaptivePadding'),\n            norm_cfg=self.norm_cfg,\n            act_cfg=dict(type='HSwish'))\n        self.add_module('layer0', layer)\n        layers.append('layer0')\n\n        layer_setting = self.arch_settings[self.arch]\n        for i, params in enumerate(layer_setting):\n            (kernel_size, mid_channels, out_channels, with_se, act,\n             stride) = params\n\n            if self.arch == 'large' and i >= 12 or self.arch == 'small' and \\\n                    i >= 8:\n                mid_channels = mid_channels // self.reduction_factor\n                out_channels = out_channels // self.reduction_factor\n\n            if with_se:\n                se_cfg = dict(\n                    channels=mid_channels,\n                    ratio=4,\n                    act_cfg=(dict(type='ReLU'),\n                             dict(type='HSigmoid', bias=3.0, divisor=6.0)))\n            else:\n                se_cfg = None\n\n            layer = InvertedResidual(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                mid_channels=mid_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                se_cfg=se_cfg,\n                with_expand_conv=(in_channels != mid_channels),\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=dict(type=act),\n                with_cp=self.with_cp)\n            in_channels = out_channels\n            layer_name = 'layer{}'.format(i + 1)\n            self.add_module(layer_name, layer)\n            layers.append(layer_name)\n\n        # build the last layer\n        # block5 layer12 os=32 for small model\n        # block6 layer16 os=32 for large model\n        layer = ConvModule(\n            in_channels=in_channels,\n            out_channels=576 if self.arch == 'small' else 960,\n            kernel_size=1,\n            stride=1,\n            dilation=4,\n            
padding=0,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=dict(type='HSwish'))\n        layer_name = 'layer{}'.format(len(layer_setting) + 1)\n        self.add_module(layer_name, layer)\n        layers.append(layer_name)\n\n        # next, convert backbone MobileNetV3 to a semantic segmentation version\n        if self.arch == 'small':\n            self.layer4.depthwise_conv.conv.stride = (1, 1)\n            self.layer9.depthwise_conv.conv.stride = (1, 1)\n            for i in range(4, len(layers)):\n                layer = getattr(self, layers[i])\n                if isinstance(layer, InvertedResidual):\n                    modified_module = layer.depthwise_conv.conv\n                else:\n                    modified_module = layer.conv\n\n                if i < 9:\n                    modified_module.dilation = (2, 2)\n                    pad = 2\n                else:\n                    modified_module.dilation = (4, 4)\n                    pad = 4\n\n                if not isinstance(modified_module, Conv2dAdaptivePadding):\n                    # Adjust padding\n                    pad *= (modified_module.kernel_size[0] - 1) // 2\n                    modified_module.padding = (pad, pad)\n        else:\n            self.layer7.depthwise_conv.conv.stride = (1, 1)\n            self.layer13.depthwise_conv.conv.stride = (1, 1)\n            for i in range(7, len(layers)):\n                layer = getattr(self, layers[i])\n                if isinstance(layer, InvertedResidual):\n                    modified_module = layer.depthwise_conv.conv\n                else:\n                    modified_module = layer.conv\n\n                if i < 13:\n                    modified_module.dilation = (2, 2)\n                    pad = 2\n                else:\n                    modified_module.dilation = (4, 4)\n                    pad = 4\n\n                if not isinstance(modified_module, Conv2dAdaptivePadding):\n                    # Adjust padding\n                    pad *= (modified_module.kernel_size[0] - 1) // 2\n                    modified_module.padding = (pad, pad)\n\n        return layers\n\n    def init_weights(self, pretrained=None):\n        if isinstance(pretrained, str):\n            logger = logging.getLogger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d):\n                    kaiming_init(m)\n                elif isinstance(m, nn.BatchNorm2d):\n                    constant_init(m, 1)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def forward(self, x):\n        outs = []\n        for i, layer_name in enumerate(self.layers):\n            layer = getattr(self, layer_name)\n            x = layer(x)\n            if i in self.out_indices:\n                outs.append(x)\n        return outs\n\n    def _freeze_stages(self):\n        for i in range(self.frozen_stages + 1):\n            layer = getattr(self, f'layer{i}')\n            layer.eval()\n            for param in layer.parameters():\n                param.requires_grad = False\n\n    def train(self, mode=True):\n        super(MobileNetV3, self).train(mode)\n        self._freeze_stages()\n        if mode and self.norm_eval:\n            for m in self.modules():\n                if isinstance(m, _BatchNorm):\n                    m.eval()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnest.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as cp\nfrom mmcv.cnn import build_conv_layer, build_norm_layer\n\nfrom ..builder import BACKBONES\nfrom ..utils import ResLayer\nfrom .resnet import Bottleneck as _Bottleneck\nfrom .resnet import ResNetV1d\n\n\nclass RSoftmax(nn.Module):\n    \"\"\"Radix Softmax module in ``SplitAttentionConv2d``.\n\n    Args:\n        radix (int): Radix of input.\n        groups (int): Groups of input.\n    \"\"\"\n\n    def __init__(self, radix, groups):\n        super().__init__()\n        self.radix = radix\n        self.groups = groups\n\n    def forward(self, x):\n        batch = x.size(0)\n        if self.radix > 1:\n            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)\n            x = F.softmax(x, dim=1)\n            x = x.reshape(batch, -1)\n        else:\n            x = torch.sigmoid(x)\n        return x\n\n\nclass SplitAttentionConv2d(nn.Module):\n    \"\"\"Split-Attention Conv2d in ResNeSt.\n\n    Args:\n        in_channels (int): Same as nn.Conv2d.\n        out_channels (int): Same as nn.Conv2d.\n        kernel_size (int | tuple[int]): Same as nn.Conv2d.\n        stride (int | tuple[int]): Same as nn.Conv2d.\n        padding (int | tuple[int]): Same as nn.Conv2d.\n        dilation (int | tuple[int]): Same as nn.Conv2d.\n        groups (int): Same as nn.Conv2d.\n        radix (int): Radix of SpltAtConv2d. Default: 2\n        reduction_factor (int): Reduction factor of inter_channels. Default: 4.\n        conv_cfg (dict): Config dict for convolution layer. Default: None,\n            which means using conv2d.\n        norm_cfg (dict): Config dict for normalization layer. Default: None.\n        dcn (dict): Config dict for DCN. 
Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 channels,\n                 kernel_size,\n                 stride=1,\n                 padding=0,\n                 dilation=1,\n                 groups=1,\n                 radix=2,\n                 reduction_factor=4,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 dcn=None):\n        super(SplitAttentionConv2d, self).__init__()\n        inter_channels = max(in_channels * radix // reduction_factor, 32)\n        self.radix = radix\n        self.groups = groups\n        self.channels = channels\n        self.with_dcn = dcn is not None\n        self.dcn = dcn\n        fallback_on_stride = False\n        if self.with_dcn:\n            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)\n        if self.with_dcn and not fallback_on_stride:\n            assert conv_cfg is None, 'conv_cfg must be None for DCN'\n            conv_cfg = dcn\n        self.conv = build_conv_layer(\n            conv_cfg,\n            in_channels,\n            channels * radix,\n            kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups * radix,\n            bias=False)\n        self.norm0_name, norm0 = build_norm_layer(\n            norm_cfg, channels * radix, postfix=0)\n        self.add_module(self.norm0_name, norm0)\n        self.relu = nn.ReLU(inplace=True)\n        self.fc1 = build_conv_layer(\n            None, channels, inter_channels, 1, groups=self.groups)\n        self.norm1_name, norm1 = build_norm_layer(\n            norm_cfg, inter_channels, postfix=1)\n        self.add_module(self.norm1_name, norm1)\n        self.fc2 = build_conv_layer(\n            None, inter_channels, channels * radix, 1, groups=self.groups)\n        self.rsoftmax = RSoftmax(radix, groups)\n\n    @property\n    def norm0(self):\n        \"\"\"nn.Module: the normalization layer named \"norm0\" \"\"\"\n        return getattr(self, self.norm0_name)\n\n    @property\n    def norm1(self):\n        \"\"\"nn.Module: the normalization layer named \"norm1\" \"\"\"\n        return getattr(self, self.norm1_name)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.norm0(x)\n        x = self.relu(x)\n\n        batch, rchannel = x.shape[:2]\n        batch = x.size(0)\n        if self.radix > 1:\n            splits = x.view(batch, self.radix, -1, *x.shape[2:])\n            gap = splits.sum(dim=1)\n        else:\n            gap = x\n        gap = F.adaptive_avg_pool2d(gap, 1)\n        gap = self.fc1(gap)\n\n        gap = self.norm1(gap)\n        gap = self.relu(gap)\n\n        atten = self.fc2(gap)\n        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)\n\n        if self.radix > 1:\n            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])\n            out = torch.sum(attens * splits, dim=1)\n        else:\n            out = atten * x\n        return out.contiguous()\n\n\nclass Bottleneck(_Bottleneck):\n    \"\"\"Bottleneck block for ResNeSt.\n\n    Args:\n        inplane (int): Input planes of this block.\n        planes (int): Middle planes of this block.\n        groups (int): Groups of conv2.\n        width_per_group (int): Width per group of conv2. 64x4d indicates\n            ``groups=64, width_per_group=4`` and 32x8d indicates\n            ``groups=32, width_per_group=8``.\n        radix (int): Radix of SpltAtConv2d. 
Default: 2\n        reduction_factor (int): Reduction factor of inter_channels in\n            SplitAttentionConv2d. Default: 4.\n        avg_down_stride (bool): Whether to use average pool for stride in\n            Bottleneck. Default: True.\n        kwargs (dict): Key word arguments for base class.\n    \"\"\"\n    expansion = 4\n\n    def __init__(self,\n                 inplanes,\n                 planes,\n                 groups=1,\n                 base_width=4,\n                 base_channels=64,\n                 radix=2,\n                 reduction_factor=4,\n                 avg_down_stride=True,\n                 **kwargs):\n        \"\"\"Bottleneck block for ResNeSt.\"\"\"\n        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)\n\n        if groups == 1:\n            width = self.planes\n        else:\n            width = math.floor(self.planes *\n                               (base_width / base_channels)) * groups\n\n        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1\n\n        self.norm1_name, norm1 = build_norm_layer(\n            self.norm_cfg, width, postfix=1)\n        self.norm3_name, norm3 = build_norm_layer(\n            self.norm_cfg, self.planes * self.expansion, postfix=3)\n\n        self.conv1 = build_conv_layer(\n            self.conv_cfg,\n            self.inplanes,\n            width,\n            kernel_size=1,\n            stride=self.conv1_stride,\n            bias=False)\n        self.add_module(self.norm1_name, norm1)\n        self.with_modulated_dcn = False\n        self.conv2 = SplitAttentionConv2d(\n            width,\n            width,\n            kernel_size=3,\n            stride=1 if self.avg_down_stride else self.conv2_stride,\n            padding=self.dilation,\n            dilation=self.dilation,\n            groups=groups,\n            radix=radix,\n            reduction_factor=reduction_factor,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            dcn=self.dcn)\n        delattr(self, self.norm2_name)\n\n        if self.avg_down_stride:\n            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)\n\n        self.conv3 = build_conv_layer(\n            self.conv_cfg,\n            width,\n            self.planes * self.expansion,\n            kernel_size=1,\n            bias=False)\n        self.add_module(self.norm3_name, norm3)\n\n    def forward(self, x):\n\n        def _inner_forward(x):\n            identity = x\n\n            out = self.conv1(x)\n            out = self.norm1(out)\n            out = self.relu(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, self.after_conv1_plugin_names)\n\n            out = self.conv2(out)\n\n            if self.avg_down_stride:\n                out = self.avd_layer(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, self.after_conv2_plugin_names)\n\n            out = self.conv3(out)\n            out = self.norm3(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, self.after_conv3_plugin_names)\n\n            if self.downsample is not None:\n                identity = self.downsample(x)\n\n            out += identity\n\n            return out\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(_inner_forward, x)\n        else:\n            out = _inner_forward(x)\n\n        out = self.relu(out)\n\n        return out\n\n\n@BACKBONES.register_module()\nclass ResNeSt(ResNetV1d):\n    
\"\"\"ResNeSt backbone.\n\n    Args:\n        groups (int): Number of groups of Bottleneck. Default: 1\n        base_width (int): Base width of Bottleneck. Default: 4\n        radix (int): Radix of SpltAtConv2d. Default: 2\n        reduction_factor (int): Reduction factor of inter_channels in\n            SplitAttentionConv2d. Default: 4.\n        avg_down_stride (bool): Whether to use average pool for stride in\n            Bottleneck. Default: True.\n        kwargs (dict): Keyword arguments for ResNet.\n    \"\"\"\n\n    arch_settings = {\n        50: (Bottleneck, (3, 4, 6, 3)),\n        101: (Bottleneck, (3, 4, 23, 3)),\n        152: (Bottleneck, (3, 8, 36, 3)),\n        200: (Bottleneck, (3, 24, 36, 3))\n    }\n\n    def __init__(self,\n                 groups=1,\n                 base_width=4,\n                 radix=2,\n                 reduction_factor=4,\n                 avg_down_stride=True,\n                 **kwargs):\n        self.groups = groups\n        self.base_width = base_width\n        self.radix = radix\n        self.reduction_factor = reduction_factor\n        self.avg_down_stride = avg_down_stride\n        super(ResNeSt, self).__init__(**kwargs)\n\n    def make_res_layer(self, **kwargs):\n        \"\"\"Pack all blocks in a stage into a ``ResLayer``.\"\"\"\n        return ResLayer(\n            groups=self.groups,\n            base_width=self.base_width,\n            base_channels=self.base_channels,\n            radix=self.radix,\n            reduction_factor=self.reduction_factor,\n            avg_down_stride=self.avg_down_stride,\n            **kwargs)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnet.py",
    "content": "import torch.nn as nn\nimport torch.utils.checkpoint as cp\nfrom mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,\n                      constant_init, kaiming_init)\nfrom mmcv.runner import load_checkpoint\nfrom mmcv.utils.parrots_wrapper import _BatchNorm\n\nfrom mmseg.utils import get_root_logger\nfrom ..builder import BACKBONES\nfrom ..utils import ResLayer\n\n\nclass BasicBlock(nn.Module):\n    \"\"\"Basic block for ResNet.\"\"\"\n\n    expansion = 1\n\n    def __init__(self,\n                 inplanes,\n                 planes,\n                 stride=1,\n                 dilation=1,\n                 downsample=None,\n                 style='pytorch',\n                 with_cp=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 dcn=None,\n                 plugins=None):\n        super(BasicBlock, self).__init__()\n        assert dcn is None, 'Not implemented yet.'\n        assert plugins is None, 'Not implemented yet.'\n\n        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)\n        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)\n\n        self.conv1 = build_conv_layer(\n            conv_cfg,\n            inplanes,\n            planes,\n            3,\n            stride=stride,\n            padding=dilation,\n            dilation=dilation,\n            bias=False)\n        self.add_module(self.norm1_name, norm1)\n        self.conv2 = build_conv_layer(\n            conv_cfg, planes, planes, 3, padding=1, bias=False)\n        self.add_module(self.norm2_name, norm2)\n\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n        self.dilation = dilation\n        self.with_cp = with_cp\n\n    @property\n    def norm1(self):\n        \"\"\"nn.Module: normalization layer after the first convolution layer\"\"\"\n        return getattr(self, self.norm1_name)\n\n    @property\n    def norm2(self):\n        \"\"\"nn.Module: normalization layer after the second convolution layer\"\"\"\n        return getattr(self, self.norm2_name)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n\n        def _inner_forward(x):\n            identity = x\n\n            out = self.conv1(x)\n            out = self.norm1(out)\n            out = self.relu(out)\n\n            out = self.conv2(out)\n            out = self.norm2(out)\n\n            if self.downsample is not None:\n                identity = self.downsample(x)\n\n            out += identity\n\n            return out\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(_inner_forward, x)\n        else:\n            out = _inner_forward(x)\n\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    \"\"\"Bottleneck block for ResNet.\n\n    If style is \"pytorch\", the stride-two layer is the 3x3 conv layer, if it is\n    \"caffe\", the stride-two layer is the first 1x1 conv layer.\n    \"\"\"\n\n    expansion = 4\n\n    def __init__(self,\n                 inplanes,\n                 planes,\n                 stride=1,\n                 dilation=1,\n                 downsample=None,\n                 style='pytorch',\n                 with_cp=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 dcn=None,\n                 plugins=None):\n        super(Bottleneck, self).__init__()\n        assert style in ['pytorch', 'caffe']\n     
   assert dcn is None or isinstance(dcn, dict)\n        assert plugins is None or isinstance(plugins, list)\n        if plugins is not None:\n            allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']\n            assert all(p['position'] in allowed_position for p in plugins)\n\n        self.inplanes = inplanes\n        self.planes = planes\n        self.stride = stride\n        self.dilation = dilation\n        self.style = style\n        self.with_cp = with_cp\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.dcn = dcn\n        self.with_dcn = dcn is not None\n        self.plugins = plugins\n        self.with_plugins = plugins is not None\n\n        if self.with_plugins:\n            # collect plugins for conv1/conv2/conv3\n            self.after_conv1_plugins = [\n                plugin['cfg'] for plugin in plugins\n                if plugin['position'] == 'after_conv1'\n            ]\n            self.after_conv2_plugins = [\n                plugin['cfg'] for plugin in plugins\n                if plugin['position'] == 'after_conv2'\n            ]\n            self.after_conv3_plugins = [\n                plugin['cfg'] for plugin in plugins\n                if plugin['position'] == 'after_conv3'\n            ]\n\n        if self.style == 'pytorch':\n            self.conv1_stride = 1\n            self.conv2_stride = stride\n        else:\n            self.conv1_stride = stride\n            self.conv2_stride = 1\n\n        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)\n        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)\n        self.norm3_name, norm3 = build_norm_layer(\n            norm_cfg, planes * self.expansion, postfix=3)\n\n        self.conv1 = build_conv_layer(\n            conv_cfg,\n            inplanes,\n            planes,\n            kernel_size=1,\n            stride=self.conv1_stride,\n            bias=False)\n        self.add_module(self.norm1_name, norm1)\n        fallback_on_stride = False\n        if self.with_dcn:\n            fallback_on_stride = dcn.pop('fallback_on_stride', False)\n        if not self.with_dcn or fallback_on_stride:\n            self.conv2 = build_conv_layer(\n                conv_cfg,\n                planes,\n                planes,\n                kernel_size=3,\n                stride=self.conv2_stride,\n                padding=dilation,\n                dilation=dilation,\n                bias=False)\n        else:\n            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'\n            self.conv2 = build_conv_layer(\n                dcn,\n                planes,\n                planes,\n                kernel_size=3,\n                stride=self.conv2_stride,\n                padding=dilation,\n                dilation=dilation,\n                bias=False)\n\n        self.add_module(self.norm2_name, norm2)\n        self.conv3 = build_conv_layer(\n            conv_cfg,\n            planes,\n            planes * self.expansion,\n            kernel_size=1,\n            bias=False)\n        self.add_module(self.norm3_name, norm3)\n\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n\n        if self.with_plugins:\n            self.after_conv1_plugin_names = self.make_block_plugins(\n                planes, self.after_conv1_plugins)\n            self.after_conv2_plugin_names = self.make_block_plugins(\n                planes, self.after_conv2_plugins)\n            self.after_conv3_plugin_names = 
self.make_block_plugins(\n                planes * self.expansion, self.after_conv3_plugins)\n\n    def make_block_plugins(self, in_channels, plugins):\n        \"\"\"Make plugins for block.\n\n        Args:\n            in_channels (int): Input channels of plugin.\n            plugins (list[dict]): List of plugins cfg to build.\n\n        Returns:\n            list[str]: List of the names of plugins.\n        \"\"\"\n        assert isinstance(plugins, list)\n        plugin_names = []\n        for plugin in plugins:\n            plugin = plugin.copy()\n            name, layer = build_plugin_layer(\n                plugin,\n                in_channels=in_channels,\n                postfix=plugin.pop('postfix', ''))\n            assert not hasattr(self, name), f'duplicate plugin {name}'\n            self.add_module(name, layer)\n            plugin_names.append(name)\n        return plugin_names\n\n    def forward_plugin(self, x, plugin_names):\n        \"\"\"Forward function for plugins.\"\"\"\n        out = x\n        for name in plugin_names:\n            # chain the plugins: each plugin consumes the previous output\n            out = getattr(self, name)(out)\n        return out\n\n    @property\n    def norm1(self):\n        \"\"\"nn.Module: normalization layer after the first convolution layer\"\"\"\n        return getattr(self, self.norm1_name)\n\n    @property\n    def norm2(self):\n        \"\"\"nn.Module: normalization layer after the second convolution layer\"\"\"\n        return getattr(self, self.norm2_name)\n\n    @property\n    def norm3(self):\n        \"\"\"nn.Module: normalization layer after the third convolution layer\"\"\"\n        return getattr(self, self.norm3_name)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n\n        def _inner_forward(x):\n            identity = x\n\n            out = self.conv1(x)\n            out = self.norm1(out)\n            out = self.relu(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, self.after_conv1_plugin_names)\n\n            out = self.conv2(out)\n            out = self.norm2(out)\n            out = self.relu(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, self.after_conv2_plugin_names)\n\n            out = self.conv3(out)\n            out = self.norm3(out)\n\n            if self.with_plugins:\n                out = self.forward_plugin(out, self.after_conv3_plugin_names)\n\n            if self.downsample is not None:\n                identity = self.downsample(x)\n\n            out += identity\n\n            return out\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(_inner_forward, x)\n        else:\n            out = _inner_forward(x)\n\n        out = self.relu(out)\n\n        return out\n\n\n@BACKBONES.register_module()\nclass ResNet(nn.Module):\n    \"\"\"ResNet backbone.\n\n    Args:\n        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.\n        in_channels (int): Number of input image channels. Default: 3.\n        stem_channels (int): Number of stem channels. Default: 64.\n        base_channels (int): Number of base channels of res layer. Default: 64.\n        num_stages (int): Resnet stages, normally 4.\n        strides (Sequence[int]): Strides of the first block of each stage.\n        dilations (Sequence[int]): Dilation of each stage.\n        out_indices (Sequence[int]): Output from which stages.\n        style (str): `pytorch` or `caffe`. 
If set to \"pytorch\", the stride-two\n            layer is the 3x3 conv layer, otherwise the stride-two layer is\n            the first 1x1 conv layer.\n        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv\n        avg_down (bool): Use AvgPool instead of stride conv when\n            downsampling in the bottleneck.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        norm_cfg (dict): Dictionary to construct and config norm layer.\n        norm_eval (bool): Whether to set norm layers to eval mode, namely,\n            freeze running stats (mean and var). Note: Effect on Batch Norm\n            and its variants only.\n        plugins (list[dict]): List of plugins for stages, each dict contains:\n\n            - cfg (dict, required): Cfg dict to build plugin.\n\n            - position (str, required): Position inside block to insert plugin,\n            options: 'after_conv1', 'after_conv2', 'after_conv3'.\n\n            - stages (tuple[bool], optional): Stages to apply plugin, length\n            should be same as 'num_stages'\n        multi_grid (Sequence[int]|None): Multi grid dilation rates of last\n            stage. Default: None\n        contract_dilation (bool): Whether contract first dilation of each layer\n            Default: False\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed.\n        zero_init_residual (bool): Whether to use zero init for last norm layer\n            in resblocks to let them behave as identity.\n\n    Example:\n        >>> from mmseg.models import ResNet\n        >>> import torch\n        >>> self = ResNet(depth=18)\n        >>> self.eval()\n        >>> inputs = torch.rand(1, 3, 32, 32)\n        >>> level_outputs = self.forward(inputs)\n        >>> for level_out in level_outputs:\n        ...     
print(tuple(level_out.shape))\n        (1, 64, 8, 8)\n        (1, 128, 4, 4)\n        (1, 256, 2, 2)\n        (1, 512, 1, 1)\n    \"\"\"\n\n    arch_settings = {\n        18: (BasicBlock, (2, 2, 2, 2)),\n        34: (BasicBlock, (3, 4, 6, 3)),\n        50: (Bottleneck, (3, 4, 6, 3)),\n        101: (Bottleneck, (3, 4, 23, 3)),\n        152: (Bottleneck, (3, 8, 36, 3))\n    }\n\n    def __init__(self,\n                 depth,\n                 in_channels=3,\n                 stem_channels=64,\n                 base_channels=64,\n                 num_stages=4,\n                 strides=(1, 2, 2, 2),\n                 dilations=(1, 1, 1, 1),\n                 out_indices=(0, 1, 2, 3),\n                 style='pytorch',\n                 deep_stem=False,\n                 avg_down=False,\n                 frozen_stages=-1,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN', requires_grad=True),\n                 norm_eval=False,\n                 dcn=None,\n                 stage_with_dcn=(False, False, False, False),\n                 plugins=None,\n                 multi_grid=None,\n                 contract_dilation=False,\n                 with_cp=False,\n                 zero_init_residual=True):\n        super(ResNet, self).__init__()\n        if depth not in self.arch_settings:\n            raise KeyError(f'invalid depth {depth} for resnet')\n        self.depth = depth\n        self.stem_channels = stem_channels\n        self.base_channels = base_channels\n        self.num_stages = num_stages\n        assert num_stages >= 1 and num_stages <= 4\n        self.strides = strides\n        self.dilations = dilations\n        assert len(strides) == len(dilations) == num_stages\n        self.out_indices = out_indices\n        assert max(out_indices) < num_stages\n        self.style = style\n        self.deep_stem = deep_stem\n        self.avg_down = avg_down\n        self.frozen_stages = frozen_stages\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.with_cp = with_cp\n        self.norm_eval = norm_eval\n        self.dcn = dcn\n        self.stage_with_dcn = stage_with_dcn\n        if dcn is not None:\n            assert len(stage_with_dcn) == num_stages\n        self.plugins = plugins\n        self.multi_grid = multi_grid\n        self.contract_dilation = contract_dilation\n        self.zero_init_residual = zero_init_residual\n        self.block, stage_blocks = self.arch_settings[depth]\n        self.stage_blocks = stage_blocks[:num_stages]\n        self.inplanes = stem_channels\n\n        self._make_stem_layer(in_channels, stem_channels)\n\n        self.res_layers = []\n        for i, num_blocks in enumerate(self.stage_blocks):\n            stride = strides[i]\n            dilation = dilations[i]\n            dcn = self.dcn if self.stage_with_dcn[i] else None\n            if plugins is not None:\n                stage_plugins = self.make_stage_plugins(plugins, i)\n            else:\n                stage_plugins = None\n            # multi grid is applied to last layer only\n            stage_multi_grid = multi_grid if i == len(\n                self.stage_blocks) - 1 else None\n            planes = base_channels * 2**i\n            res_layer = self.make_res_layer(\n                block=self.block,\n                inplanes=self.inplanes,\n                planes=planes,\n                num_blocks=num_blocks,\n                stride=stride,\n                dilation=dilation,\n                style=self.style,\n                
avg_down=self.avg_down,\n                with_cp=with_cp,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                dcn=dcn,\n                plugins=stage_plugins,\n                multi_grid=stage_multi_grid,\n                contract_dilation=contract_dilation)\n            self.inplanes = planes * self.block.expansion\n            layer_name = f'layer{i+1}'\n            self.add_module(layer_name, res_layer)\n            self.res_layers.append(layer_name)\n\n        self._freeze_stages()\n\n        self.feat_dim = self.block.expansion * base_channels * 2**(\n            len(self.stage_blocks) - 1)\n\n    def make_stage_plugins(self, plugins, stage_idx):\n        \"\"\"make plugins for ResNet 'stage_idx'th stage .\n\n        Currently we support to insert 'context_block',\n        'empirical_attention_block', 'nonlocal_block' into the backbone like\n        ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of\n        Bottleneck.\n\n        An example of plugins format could be :\n        >>> plugins=[\n        ...     dict(cfg=dict(type='xxx', arg1='xxx'),\n        ...          stages=(False, True, True, True),\n        ...          position='after_conv2'),\n        ...     dict(cfg=dict(type='yyy'),\n        ...          stages=(True, True, True, True),\n        ...          position='after_conv3'),\n        ...     dict(cfg=dict(type='zzz', postfix='1'),\n        ...          stages=(True, True, True, True),\n        ...          position='after_conv3'),\n        ...     dict(cfg=dict(type='zzz', postfix='2'),\n        ...          stages=(True, True, True, True),\n        ...          position='after_conv3')\n        ... ]\n        >>> self = ResNet(depth=18)\n        >>> stage_plugins = self.make_stage_plugins(plugins, 0)\n        >>> assert len(stage_plugins) == 3\n\n        Suppose 'stage_idx=0', the structure of blocks in the stage would be:\n            conv1-> conv2->conv3->yyy->zzz1->zzz2\n        Suppose 'stage_idx=1', the structure of blocks in the stage would be:\n            conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2\n\n        If stages is missing, the plugin would be applied to all stages.\n\n        Args:\n            plugins (list[dict]): List of plugins cfg to build. 
The postfix is\n                required if multiple same type plugins are inserted.\n            stage_idx (int): Index of stage to build\n\n        Returns:\n            list[dict]: Plugins for current stage\n        \"\"\"\n        stage_plugins = []\n        for plugin in plugins:\n            plugin = plugin.copy()\n            stages = plugin.pop('stages', None)\n            assert stages is None or len(stages) == self.num_stages\n            # whether to insert plugin into current stage\n            if stages is None or stages[stage_idx]:\n                stage_plugins.append(plugin)\n\n        return stage_plugins\n\n    def make_res_layer(self, **kwargs):\n        \"\"\"Pack all blocks in a stage into a ``ResLayer``.\"\"\"\n        return ResLayer(**kwargs)\n\n    @property\n    def norm1(self):\n        \"\"\"nn.Module: the normalization layer named \"norm1\" \"\"\"\n        return getattr(self, self.norm1_name)\n\n    def _make_stem_layer(self, in_channels, stem_channels):\n        \"\"\"Make stem layer for ResNet.\"\"\"\n        if self.deep_stem:\n            self.stem = nn.Sequential(\n                build_conv_layer(\n                    self.conv_cfg,\n                    in_channels,\n                    stem_channels // 2,\n                    kernel_size=3,\n                    stride=2,\n                    padding=1,\n                    bias=False),\n                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],\n                nn.ReLU(inplace=True),\n                build_conv_layer(\n                    self.conv_cfg,\n                    stem_channels // 2,\n                    stem_channels // 2,\n                    kernel_size=3,\n                    stride=1,\n                    padding=1,\n                    bias=False),\n                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],\n                nn.ReLU(inplace=True),\n                build_conv_layer(\n                    self.conv_cfg,\n                    stem_channels // 2,\n                    stem_channels,\n                    kernel_size=3,\n                    stride=1,\n                    padding=1,\n                    bias=False),\n                build_norm_layer(self.norm_cfg, stem_channels)[1],\n                nn.ReLU(inplace=True))\n        else:\n            self.conv1 = build_conv_layer(\n                self.conv_cfg,\n                in_channels,\n                stem_channels,\n                kernel_size=7,\n                stride=2,\n                padding=3,\n                bias=False)\n            self.norm1_name, norm1 = build_norm_layer(\n                self.norm_cfg, stem_channels, postfix=1)\n            self.add_module(self.norm1_name, norm1)\n            self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n    def _freeze_stages(self):\n        \"\"\"Freeze stages param and norm stats.\"\"\"\n        if self.frozen_stages >= 0:\n            if self.deep_stem:\n                self.stem.eval()\n                for param in self.stem.parameters():\n                    param.requires_grad = False\n            else:\n                self.norm1.eval()\n                for m in [self.conv1, self.norm1]:\n                    for param in m.parameters():\n                        param.requires_grad = False\n\n        for i in range(1, self.frozen_stages + 1):\n            m = getattr(self, f'layer{i}')\n            m.eval()\n            for param in m.parameters():\n                
param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        if isinstance(pretrained, str):\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d):\n                    kaiming_init(m)\n                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):\n                    constant_init(m, 1)\n\n            if self.dcn is not None:\n                for m in self.modules():\n                    if isinstance(m, Bottleneck) and hasattr(\n                            m, 'conv2_offset'):\n                        constant_init(m.conv2_offset, 0)\n\n            if self.zero_init_residual:\n                for m in self.modules():\n                    if isinstance(m, Bottleneck):\n                        constant_init(m.norm3, 0)\n                    elif isinstance(m, BasicBlock):\n                        constant_init(m.norm2, 0)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        if self.deep_stem:\n            x = self.stem(x)\n        else:\n            x = self.conv1(x)\n            x = self.norm1(x)\n            x = self.relu(x)\n        x = self.maxpool(x)\n        outs = []\n        for i, layer_name in enumerate(self.res_layers):\n            res_layer = getattr(self, layer_name)\n            x = res_layer(x)\n            if i in self.out_indices:\n                outs.append(x)\n        return tuple(outs)\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep normalization layer\n        freezed.\"\"\"\n        super(ResNet, self).train(mode)\n        self._freeze_stages()\n        if mode and self.norm_eval:\n            for m in self.modules():\n                # trick: eval have effect on BatchNorm only\n                if isinstance(m, _BatchNorm):\n                    m.eval()\n\n\n@BACKBONES.register_module()\nclass ResNetV1c(ResNet):\n    \"\"\"ResNetV1c variant described in [1]_.\n\n    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv\n    in the input stem with three 3x3 convs.\n\n    References:\n        .. [1] https://arxiv.org/pdf/1812.01187.pdf\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super(ResNetV1c, self).__init__(\n            deep_stem=True, avg_down=False, **kwargs)\n\n\n@BACKBONES.register_module()\nclass ResNetV1d(ResNet):\n    \"\"\"ResNetV1d variant described in [1]_.\n\n    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in\n    the input stem with three 3x3 convs. And in the downsampling block, a 2x2\n    avg_pool with stride 2 is added before conv, whose stride is changed to 1.\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super(ResNetV1d, self).__init__(\n            deep_stem=True, avg_down=True, **kwargs)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnext.py",
    "content": "import math\n\nfrom mmcv.cnn import build_conv_layer, build_norm_layer\n\nfrom ..builder import BACKBONES\nfrom ..utils import ResLayer\nfrom .resnet import Bottleneck as _Bottleneck\nfrom .resnet import ResNet\n\n\nclass Bottleneck(_Bottleneck):\n    \"\"\"Bottleneck block for ResNeXt.\n\n    If style is \"pytorch\", the stride-two layer is the 3x3 conv layer, if it is\n    \"caffe\", the stride-two layer is the first 1x1 conv layer.\n    \"\"\"\n\n    def __init__(self,\n                 inplanes,\n                 planes,\n                 groups=1,\n                 base_width=4,\n                 base_channels=64,\n                 **kwargs):\n        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)\n\n        if groups == 1:\n            width = self.planes\n        else:\n            width = math.floor(self.planes *\n                               (base_width / base_channels)) * groups\n\n        self.norm1_name, norm1 = build_norm_layer(\n            self.norm_cfg, width, postfix=1)\n        self.norm2_name, norm2 = build_norm_layer(\n            self.norm_cfg, width, postfix=2)\n        self.norm3_name, norm3 = build_norm_layer(\n            self.norm_cfg, self.planes * self.expansion, postfix=3)\n\n        self.conv1 = build_conv_layer(\n            self.conv_cfg,\n            self.inplanes,\n            width,\n            kernel_size=1,\n            stride=self.conv1_stride,\n            bias=False)\n        self.add_module(self.norm1_name, norm1)\n        fallback_on_stride = False\n        self.with_modulated_dcn = False\n        if self.with_dcn:\n            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)\n        if not self.with_dcn or fallback_on_stride:\n            self.conv2 = build_conv_layer(\n                self.conv_cfg,\n                width,\n                width,\n                kernel_size=3,\n                stride=self.conv2_stride,\n                padding=self.dilation,\n                dilation=self.dilation,\n                groups=groups,\n                bias=False)\n        else:\n            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'\n            self.conv2 = build_conv_layer(\n                self.dcn,\n                width,\n                width,\n                kernel_size=3,\n                stride=self.conv2_stride,\n                padding=self.dilation,\n                dilation=self.dilation,\n                groups=groups,\n                bias=False)\n\n        self.add_module(self.norm2_name, norm2)\n        self.conv3 = build_conv_layer(\n            self.conv_cfg,\n            width,\n            self.planes * self.expansion,\n            kernel_size=1,\n            bias=False)\n        self.add_module(self.norm3_name, norm3)\n\n\n@BACKBONES.register_module()\nclass ResNeXt(ResNet):\n    \"\"\"ResNeXt backbone.\n\n    Args:\n        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.\n        in_channels (int): Number of input image channels. Normally 3.\n        num_stages (int): Resnet stages, normally 4.\n        groups (int): Group of resnext.\n        base_width (int): Base width of resnext.\n        strides (Sequence[int]): Strides of the first block of each stage.\n        dilations (Sequence[int]): Dilation of each stage.\n        out_indices (Sequence[int]): Output from which stages.\n        style (str): `pytorch` or `caffe`. 
If set to \"pytorch\", the stride-two\n            layer is the 3x3 conv layer, otherwise the stride-two layer is\n            the first 1x1 conv layer.\n        frozen_stages (int): Stages to be frozen (all param fixed). -1 means\n            not freezing any parameters.\n        norm_cfg (dict): dictionary to construct and config norm layer.\n        norm_eval (bool): Whether to set norm layers to eval mode, namely,\n            freeze running stats (mean and var). Note: Effect on Batch Norm\n            and its variants only.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed.\n        zero_init_residual (bool): whether to use zero init for last norm layer\n            in resblocks to let them behave as identity.\n\n    Example:\n        >>> from mmseg.models import ResNeXt\n        >>> import torch\n        >>> self = ResNeXt(depth=50)\n        >>> self.eval()\n        >>> inputs = torch.rand(1, 3, 32, 32)\n        >>> level_outputs = self.forward(inputs)\n        >>> for level_out in level_outputs:\n        ...     print(tuple(level_out.shape))\n        (1, 256, 8, 8)\n        (1, 512, 4, 4)\n        (1, 1024, 2, 2)\n        (1, 2048, 1, 1)\n    \"\"\"\n\n    arch_settings = {\n        50: (Bottleneck, (3, 4, 6, 3)),\n        101: (Bottleneck, (3, 4, 23, 3)),\n        152: (Bottleneck, (3, 8, 36, 3))\n    }\n\n    def __init__(self, groups=1, base_width=4, **kwargs):\n        self.groups = groups\n        self.base_width = base_width\n        super(ResNeXt, self).__init__(**kwargs)\n\n    def make_res_layer(self, **kwargs):\n        \"\"\"Pack all blocks in a stage into a ``ResLayer``\"\"\"\n        return ResLayer(\n            groups=self.groups,\n            base_width=self.base_width,\n            base_channels=self.base_channels,\n            **kwargs)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/backbones/unet.py",
    "content": "import torch.nn as nn\nimport torch.utils.checkpoint as cp\nfrom mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,\n                      build_norm_layer, constant_init, kaiming_init)\nfrom mmcv.runner import load_checkpoint\nfrom mmcv.utils.parrots_wrapper import _BatchNorm\n\nfrom mmseg.utils import get_root_logger\nfrom ..builder import BACKBONES\nfrom ..utils import UpConvBlock\n\n\nclass BasicConvBlock(nn.Module):\n    \"\"\"Basic convolutional block for UNet.\n\n    This module consists of several plain convolutional layers.\n\n    Args:\n        in_channels (int): Number of input channels.\n        out_channels (int): Number of output channels.\n        num_convs (int): Number of convolutional layers. Default: 2.\n        stride (int): Whether use stride convolution to downsample\n            the input feature map. If stride=2, it only uses stride convolution\n            in the first convolutional layer to downsample the input feature\n            map. Options are 1 or 2. Default: 1.\n        dilation (int): Whether use dilated convolution to expand the\n            receptive field. Set dilation rate of each convolutional layer and\n            the dilation rate of the first convolutional layer is always 1.\n            Default: 1.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n        conv_cfg (dict | None): Config dict for convolution layer.\n            Default: None.\n        norm_cfg (dict | None): Config dict for normalization layer.\n            Default: dict(type='BN').\n        act_cfg (dict | None): Config dict for activation layer in ConvModule.\n            Default: dict(type='ReLU').\n        dcn (bool): Use deformable convoluton in convolutional layer or not.\n            Default: None.\n        plugins (dict): plugins for convolutional layers. 
Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 num_convs=2,\n                 stride=1,\n                 dilation=1,\n                 with_cp=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 dcn=None,\n                 plugins=None):\n        super(BasicConvBlock, self).__init__()\n        assert dcn is None, 'Not implemented yet.'\n        assert plugins is None, 'Not implemented yet.'\n\n        self.with_cp = with_cp\n        convs = []\n        for i in range(num_convs):\n            convs.append(\n                ConvModule(\n                    in_channels=in_channels if i == 0 else out_channels,\n                    out_channels=out_channels,\n                    kernel_size=3,\n                    stride=stride if i == 0 else 1,\n                    dilation=1 if i == 0 else dilation,\n                    padding=1 if i == 0 else dilation,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg))\n\n        self.convs = nn.Sequential(*convs)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(self.convs, x)\n        else:\n            out = self.convs(x)\n        return out\n\n\n@UPSAMPLE_LAYERS.register_module()\nclass DeconvModule(nn.Module):\n    \"\"\"Deconvolution upsample module in decoder for UNet (2X upsample).\n\n    This module uses deconvolution to upsample feature map in the decoder\n    of UNet.\n\n    Args:\n        in_channels (int): Number of input channels.\n        out_channels (int): Number of output channels.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n        norm_cfg (dict | None): Config dict for normalization layer.\n            Default: dict(type='BN').\n        act_cfg (dict | None): Config dict for activation layer in ConvModule.\n            Default: dict(type='ReLU').\n        kernel_size (int): Kernel size of the convolutional layer. 
Default: 4.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 with_cp=False,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 *,\n                 kernel_size=4,\n                 scale_factor=2):\n        super(DeconvModule, self).__init__()\n\n        assert (kernel_size - scale_factor >= 0) and\\\n               (kernel_size - scale_factor) % 2 == 0,\\\n               f'kernel_size should be greater than or equal to scale_factor '\\\n               f'and (kernel_size - scale_factor) should be even numbers, '\\\n               f'while the kernel size is {kernel_size} and scale_factor is '\\\n               f'{scale_factor}.'\n\n        stride = scale_factor\n        padding = (kernel_size - scale_factor) // 2\n        self.with_cp = with_cp\n        deconv = nn.ConvTranspose2d(\n            in_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding)\n\n        norm_name, norm = build_norm_layer(norm_cfg, out_channels)\n        activate = build_activation_layer(act_cfg)\n        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(self.deconv_upsamping, x)\n        else:\n            out = self.deconv_upsamping(x)\n        return out\n\n\n@UPSAMPLE_LAYERS.register_module()\nclass InterpConv(nn.Module):\n    \"\"\"Interpolation upsample module in decoder for UNet.\n\n    This module uses interpolation to upsample feature map in the decoder\n    of UNet. It consists of one interpolation upsample layer and one\n    convolutional layer. It can be one interpolation upsample layer followed\n    by one convolutional layer (conv_first=False) or one convolutional layer\n    followed by one interpolation upsample layer (conv_first=True).\n\n    Args:\n        in_channels (int): Number of input channels.\n        out_channels (int): Number of output channels.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n        norm_cfg (dict | None): Config dict for normalization layer.\n            Default: dict(type='BN').\n        act_cfg (dict | None): Config dict for activation layer in ConvModule.\n            Default: dict(type='ReLU').\n        conv_cfg (dict | None): Config dict for convolution layer.\n            Default: None.\n        conv_first (bool): Whether convolutional layer or interpolation\n            upsample layer first. Default: False. It means interpolation\n            upsample layer followed by one convolutional layer.\n        kernel_size (int): Kernel size of the convolutional layer. Default: 1.\n        stride (int): Stride of the convolutional layer. Default: 1.\n        padding (int): Padding of the convolutional layer. 
Default: 1.\n        upsampe_cfg (dict): Interpolation config of the upsample layer.\n            Default: dict(\n                scale_factor=2, mode='bilinear', align_corners=False).\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 with_cp=False,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 *,\n                 conv_cfg=None,\n                 conv_first=False,\n                 kernel_size=1,\n                 stride=1,\n                 padding=0,\n                 upsampe_cfg=dict(\n                     scale_factor=2, mode='bilinear', align_corners=False)):\n        super(InterpConv, self).__init__()\n\n        self.with_cp = with_cp\n        conv = ConvModule(\n            in_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n        upsample = nn.Upsample(**upsampe_cfg)\n        if conv_first:\n            self.interp_upsample = nn.Sequential(conv, upsample)\n        else:\n            self.interp_upsample = nn.Sequential(upsample, conv)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(self.interp_upsample, x)\n        else:\n            out = self.interp_upsample(x)\n        return out\n\n\n@BACKBONES.register_module()\nclass UNet(nn.Module):\n    \"\"\"UNet backbone.\n    U-Net: Convolutional Networks for Biomedical Image Segmentation.\n    https://arxiv.org/pdf/1505.04597.pdf\n\n    Args:\n        in_channels (int): Number of input image channels. Default\" 3.\n        base_channels (int): Number of base channels of each stage.\n            The output channels of the first stage. Default: 64.\n        num_stages (int): Number of stages in encoder, normally 5. Default: 5.\n        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.\n            len(strides) is equal to num_stages. Normally the stride of the\n            first stage in encoder is 1. If strides[i]=2, it uses stride\n            convolution to downsample in the correspondance encoder stage.\n            Default: (1, 1, 1, 1, 1).\n        enc_num_convs (Sequence[int]): Number of convolutional layers in the\n            convolution block of the correspondance encoder stage.\n            Default: (2, 2, 2, 2, 2).\n        dec_num_convs (Sequence[int]): Number of convolutional layers in the\n            convolution block of the correspondance decoder stage.\n            Default: (2, 2, 2, 2).\n        downsamples (Sequence[int]): Whether use MaxPool to downsample the\n            feature map after the first stage of encoder\n            (stages: [1, num_stages)). If the correspondance encoder stage use\n            stride convolution (strides[i]=2), it will never use MaxPool to\n            downsample, even downsamples[i-1]=True.\n            Default: (True, True, True, True).\n        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.\n            Default: (1, 1, 1, 1, 1).\n        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.\n            Default: (1, 1, 1, 1).\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. 
Default: False.\n        conv_cfg (dict | None): Config dict for convolution layer.\n            Default: None.\n        norm_cfg (dict | None): Config dict for normalization layer.\n            Default: dict(type='BN').\n        act_cfg (dict | None): Config dict for activation layer in ConvModule.\n            Default: dict(type='ReLU').\n        upsample_cfg (dict): The upsample config of the upsample module in\n            decoder. Default: dict(type='InterpConv').\n        norm_eval (bool): Whether to set norm layers to eval mode, namely,\n            freeze running stats (mean and var). Note: Effect on Batch Norm\n            and its variants only. Default: False.\n        dcn (bool): Use deformable convoluton in convolutional layer or not.\n            Default: None.\n        plugins (dict): plugins for convolutional layers. Default: None.\n\n    Notice:\n        The input image size should be devisible by the whole downsample rate\n        of the encoder. More detail of the whole downsample rate can be found\n        in UNet._check_input_devisible.\n\n    \"\"\"\n\n    def __init__(self,\n                 in_channels=3,\n                 base_channels=64,\n                 num_stages=5,\n                 strides=(1, 1, 1, 1, 1),\n                 enc_num_convs=(2, 2, 2, 2, 2),\n                 dec_num_convs=(2, 2, 2, 2),\n                 downsamples=(True, True, True, True),\n                 enc_dilations=(1, 1, 1, 1, 1),\n                 dec_dilations=(1, 1, 1, 1),\n                 with_cp=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 upsample_cfg=dict(type='InterpConv'),\n                 norm_eval=False,\n                 dcn=None,\n                 plugins=None):\n        super(UNet, self).__init__()\n        assert dcn is None, 'Not implemented yet.'\n        assert plugins is None, 'Not implemented yet.'\n        assert len(strides) == num_stages, \\\n            'The length of strides should be equal to num_stages, '\\\n            f'while the strides is {strides}, the length of '\\\n            f'strides is {len(strides)}, and the num_stages is '\\\n            f'{num_stages}.'\n        assert len(enc_num_convs) == num_stages, \\\n            'The length of enc_num_convs should be equal to num_stages, '\\\n            f'while the enc_num_convs is {enc_num_convs}, the length of '\\\n            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\\\n            f'{num_stages}.'\n        assert len(dec_num_convs) == (num_stages-1), \\\n            'The length of dec_num_convs should be equal to (num_stages-1), '\\\n            f'while the dec_num_convs is {dec_num_convs}, the length of '\\\n            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\\\n            f'{num_stages}.'\n        assert len(downsamples) == (num_stages-1), \\\n            'The length of downsamples should be equal to (num_stages-1), '\\\n            f'while the downsamples is {downsamples}, the length of '\\\n            f'downsamples is {len(downsamples)}, and the num_stages is '\\\n            f'{num_stages}.'\n        assert len(enc_dilations) == num_stages, \\\n            'The length of enc_dilations should be equal to num_stages, '\\\n            f'while the enc_dilations is {enc_dilations}, the length of '\\\n            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\\\n            f'{num_stages}.'\n        assert len(dec_dilations) == 
(num_stages-1), \\\n            'The length of dec_dilations should be equal to (num_stages-1), '\\\n            f'while the dec_dilations is {dec_dilations}, the length of '\\\n            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\\\n            f'{num_stages}.'\n        self.num_stages = num_stages\n        self.strides = strides\n        self.downsamples = downsamples\n        self.norm_eval = norm_eval\n\n        self.encoder = nn.ModuleList()\n        self.decoder = nn.ModuleList()\n\n        for i in range(num_stages):\n            enc_conv_block = []\n            if i != 0:\n                if strides[i] == 1 and downsamples[i - 1]:\n                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))\n                upsample = (strides[i] != 1 or downsamples[i - 1])\n                self.decoder.append(\n                    UpConvBlock(\n                        conv_block=BasicConvBlock,\n                        in_channels=base_channels * 2**i,\n                        skip_channels=base_channels * 2**(i - 1),\n                        out_channels=base_channels * 2**(i - 1),\n                        num_convs=dec_num_convs[i - 1],\n                        stride=1,\n                        dilation=dec_dilations[i - 1],\n                        with_cp=with_cp,\n                        conv_cfg=conv_cfg,\n                        norm_cfg=norm_cfg,\n                        act_cfg=act_cfg,\n                        upsample_cfg=upsample_cfg if upsample else None,\n                        dcn=None,\n                        plugins=None))\n\n            enc_conv_block.append(\n                BasicConvBlock(\n                    in_channels=in_channels,\n                    out_channels=base_channels * 2**i,\n                    num_convs=enc_num_convs[i],\n                    stride=strides[i],\n                    dilation=enc_dilations[i],\n                    with_cp=with_cp,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg,\n                    dcn=None,\n                    plugins=None))\n            self.encoder.append((nn.Sequential(*enc_conv_block)))\n            in_channels = base_channels * 2**i\n\n    def forward(self, x):\n        self._check_input_devisible(x)\n        enc_outs = []\n        for enc in self.encoder:\n            x = enc(x)\n            enc_outs.append(x)\n        dec_outs = [x]\n        for i in reversed(range(len(self.decoder))):\n            x = self.decoder[i](enc_outs[i], x)\n            dec_outs.append(x)\n\n        return dec_outs\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep normalization layer\n        freezed.\"\"\"\n        super(UNet, self).train(mode)\n        if mode and self.norm_eval:\n            for m in self.modules():\n                # trick: eval have effect on BatchNorm only\n                if isinstance(m, _BatchNorm):\n                    m.eval()\n\n    def _check_input_devisible(self, x):\n        h, w = x.shape[-2:]\n        whole_downsample_rate = 1\n        for i in range(1, self.num_stages):\n            if self.strides[i] == 2 or self.downsamples[i - 1]:\n                whole_downsample_rate *= 2\n        assert (h % whole_downsample_rate == 0) \\\n            and (w % whole_downsample_rate == 0),\\\n            f'The input image size {(h, w)} should be devisible by the whole '\\\n            f'downsample rate {whole_downsample_rate}, when num_stages is '\\\n            
f'{self.num_stages}, strides is {self.strides}, and downsamples '\\\n            f'is {self.downsamples}.'\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        if isinstance(pretrained, str):\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d):\n                    kaiming_init(m)\n                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):\n                    constant_init(m, 1)\n        else:\n            raise TypeError('pretrained must be a str or None')\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/builder.py",
    "content": "import warnings\n\nfrom mmcv.utils import Registry, build_from_cfg\nfrom torch import nn\n\nBACKBONES = Registry('backbone')\nNECKS = Registry('neck')\nHEADS = Registry('head')\nLOSSES = Registry('loss')\nSEGMENTORS = Registry('segmentor')\n\n\ndef build(cfg, registry, default_args=None):\n    \"\"\"Build a module.\n\n    Args:\n        cfg (dict, list[dict]): The config of modules, is is either a dict\n            or a list of configs.\n        registry (:obj:`Registry`): A registry the module belongs to.\n        default_args (dict, optional): Default arguments to build the module.\n            Defaults to None.\n\n    Returns:\n        nn.Module: A built nn module.\n    \"\"\"\n\n    if isinstance(cfg, list):\n        modules = [\n            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg\n        ]\n        return nn.Sequential(*modules)\n    else:\n        return build_from_cfg(cfg, registry, default_args)\n\n\ndef build_backbone(cfg):\n    \"\"\"Build backbone.\"\"\"\n    return build(cfg, BACKBONES)\n\n\ndef build_neck(cfg):\n    \"\"\"Build neck.\"\"\"\n    return build(cfg, NECKS)\n\n\ndef build_head(cfg):\n    \"\"\"Build head.\"\"\"\n    return build(cfg, HEADS)\n\n\ndef build_loss(cfg):\n    \"\"\"Build loss.\"\"\"\n    return build(cfg, LOSSES)\n\n\ndef build_segmentor(cfg, train_cfg=None, test_cfg=None):\n    \"\"\"Build segmentor.\"\"\"\n    if train_cfg is not None or test_cfg is not None:\n        warnings.warn(\n            'train_cfg and test_cfg is deprecated, '\n            'please specify them in model', UserWarning)\n    assert cfg.get('train_cfg') is None or train_cfg is None, \\\n        'train_cfg specified in both outer field and model field '\n    assert cfg.get('test_cfg') is None or test_cfg is None, \\\n        'test_cfg specified in both outer field and model field '\n    return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/__init__.py",
    "content": "from .ann_head import ANNHead\nfrom .apc_head import APCHead\nfrom .aspp_head import ASPPHead\nfrom .cc_head import CCHead\nfrom .da_head import DAHead\nfrom .dm_head import DMHead\nfrom .dnl_head import DNLHead\nfrom .ema_head import EMAHead\nfrom .enc_head import EncHead\nfrom .fcn_head import FCNHead\nfrom .fpn_head import FPNHead\nfrom .gc_head import GCHead\nfrom .lraspp_head import LRASPPHead\nfrom .nl_head import NLHead\nfrom .ocr_head import OCRHead\nfrom .point_head import PointHead\nfrom .psa_head import PSAHead\nfrom .psp_head import PSPHead\nfrom .sep_aspp_head import DepthwiseSeparableASPPHead\nfrom .sep_fcn_head import DepthwiseSeparableFCNHead\nfrom .uper_head import UPerHead\n\n__all__ = [\n    'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',\n    'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',\n    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',\n    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead'\n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ann_head.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule\n\nfrom ..builder import HEADS\nfrom ..utils import SelfAttentionBlock as _SelfAttentionBlock\nfrom .decode_head import BaseDecodeHead\n\n\nclass PPMConcat(nn.ModuleList):\n    \"\"\"Pyramid Pooling Module that only concat the features of each layer.\n\n    Args:\n        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid\n            Module.\n    \"\"\"\n\n    def __init__(self, pool_scales=(1, 3, 6, 8)):\n        super(PPMConcat, self).__init__(\n            [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])\n\n    def forward(self, feats):\n        \"\"\"Forward function.\"\"\"\n        ppm_outs = []\n        for ppm in self:\n            ppm_out = ppm(feats)\n            ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))\n        concat_outs = torch.cat(ppm_outs, dim=2)\n        return concat_outs\n\n\nclass SelfAttentionBlock(_SelfAttentionBlock):\n    \"\"\"Make a ANN used SelfAttentionBlock.\n\n    Args:\n        low_in_channels (int): Input channels of lower level feature,\n            which is the key feature for self-attention.\n        high_in_channels (int): Input channels of higher level feature,\n            which is the query feature for self-attention.\n        channels (int): Output channels of key/query transform.\n        out_channels (int): Output channels.\n        share_key_query (bool): Whether share projection weight between key\n            and query projection.\n        query_scale (int): The scale of query feature map.\n        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid\n            Module of key feature.\n        conv_cfg (dict|None): Config of conv layers.\n        norm_cfg (dict|None): Config of norm layers.\n        act_cfg (dict|None): Config of activation layers.\n    \"\"\"\n\n    def __init__(self, low_in_channels, high_in_channels, channels,\n                 out_channels, share_key_query, query_scale, key_pool_scales,\n                 conv_cfg, norm_cfg, act_cfg):\n        key_psp = PPMConcat(key_pool_scales)\n        if query_scale > 1:\n            query_downsample = nn.MaxPool2d(kernel_size=query_scale)\n        else:\n            query_downsample = None\n        super(SelfAttentionBlock, self).__init__(\n            key_in_channels=low_in_channels,\n            query_in_channels=high_in_channels,\n            channels=channels,\n            out_channels=out_channels,\n            share_key_query=share_key_query,\n            query_downsample=query_downsample,\n            key_downsample=key_psp,\n            key_query_num_convs=1,\n            key_query_norm=True,\n            value_out_num_convs=1,\n            value_out_norm=False,\n            matmul_norm=True,\n            with_out=True,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n\n\nclass AFNB(nn.Module):\n    \"\"\"Asymmetric Fusion Non-local Block(AFNB)\n\n    Args:\n        low_in_channels (int): Input channels of lower level feature,\n            which is the key feature for self-attention.\n        high_in_channels (int): Input channels of higher level feature,\n            which is the query feature for self-attention.\n        channels (int): Output channels of key/query transform.\n        out_channels (int): Output channels.\n            and query projection.\n        query_scales (tuple[int]): The scales of query feature map.\n            Default: (1,)\n        key_pool_scales (tuple[int]): 
Pooling scales used in Pooling Pyramid\n            Module of key feature.\n        conv_cfg (dict|None): Config of conv layers.\n        norm_cfg (dict|None): Config of norm layers.\n        act_cfg (dict|None): Config of activation layers.\n    \"\"\"\n\n    def __init__(self, low_in_channels, high_in_channels, channels,\n                 out_channels, query_scales, key_pool_scales, conv_cfg,\n                 norm_cfg, act_cfg):\n        super(AFNB, self).__init__()\n        self.stages = nn.ModuleList()\n        for query_scale in query_scales:\n            self.stages.append(\n                SelfAttentionBlock(\n                    low_in_channels=low_in_channels,\n                    high_in_channels=high_in_channels,\n                    channels=channels,\n                    out_channels=out_channels,\n                    share_key_query=False,\n                    query_scale=query_scale,\n                    key_pool_scales=key_pool_scales,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg))\n        self.bottleneck = ConvModule(\n            out_channels + high_in_channels,\n            out_channels,\n            1,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=None)\n\n    def forward(self, low_feats, high_feats):\n        \"\"\"Forward function.\"\"\"\n        priors = [stage(high_feats, low_feats) for stage in self.stages]\n        context = torch.stack(priors, dim=0).sum(dim=0)\n        output = self.bottleneck(torch.cat([context, high_feats], 1))\n        return output\n\n\nclass APNB(nn.Module):\n    \"\"\"Asymmetric Pyramid Non-local Block (APNB)\n\n    Args:\n        in_channels (int): Input channels of key/query feature,\n            which is the key feature for self-attention.\n        channels (int): Output channels of key/query transform.\n        out_channels (int): Output channels.\n        query_scales (tuple[int]): The scales of query feature map.\n            Default: (1,)\n        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid\n            Module of key feature.\n        conv_cfg (dict|None): Config of conv layers.\n        norm_cfg (dict|None): Config of norm layers.\n        act_cfg (dict|None): Config of activation layers.\n    \"\"\"\n\n    def __init__(self, in_channels, channels, out_channels, query_scales,\n                 key_pool_scales, conv_cfg, norm_cfg, act_cfg):\n        super(APNB, self).__init__()\n        self.stages = nn.ModuleList()\n        for query_scale in query_scales:\n            self.stages.append(\n                SelfAttentionBlock(\n                    low_in_channels=in_channels,\n                    high_in_channels=in_channels,\n                    channels=channels,\n                    out_channels=out_channels,\n                    share_key_query=True,\n                    query_scale=query_scale,\n                    key_pool_scales=key_pool_scales,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg))\n        self.bottleneck = ConvModule(\n            2 * in_channels,\n            out_channels,\n            1,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n\n    def forward(self, feats):\n        \"\"\"Forward function.\"\"\"\n        priors = [stage(feats, feats) for stage in self.stages]\n        context = torch.stack(priors, dim=0).sum(dim=0)\n        output = 
self.bottleneck(torch.cat([context, feats], 1))\n        return output\n\n\n@HEADS.register_module()\nclass ANNHead(BaseDecodeHead):\n    \"\"\"Asymmetric Non-local Neural Networks for Semantic Segmentation.\n\n    This head is the implementation of `ANNNet\n    <https://arxiv.org/abs/1908.07678>`_.\n\n    Args:\n        project_channels (int): Projection channels for Nonlocal.\n        query_scales (tuple[int]): The scales of query feature map.\n            Default: (1,)\n        key_pool_scales (tuple[int]): The pooling scales of key feature map.\n            Default: (1, 3, 6, 8).\n    \"\"\"\n\n    def __init__(self,\n                 project_channels,\n                 query_scales=(1, ),\n                 key_pool_scales=(1, 3, 6, 8),\n                 **kwargs):\n        super(ANNHead, self).__init__(\n            input_transform='multiple_select', **kwargs)\n        assert len(self.in_channels) == 2\n        low_in_channels, high_in_channels = self.in_channels\n        self.project_channels = project_channels\n        self.fusion = AFNB(\n            low_in_channels=low_in_channels,\n            high_in_channels=high_in_channels,\n            out_channels=high_in_channels,\n            channels=project_channels,\n            query_scales=query_scales,\n            key_pool_scales=key_pool_scales,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.bottleneck = ConvModule(\n            high_in_channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.context = APNB(\n            in_channels=self.channels,\n            out_channels=self.channels,\n            channels=project_channels,\n            query_scales=query_scales,\n            key_pool_scales=key_pool_scales,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        low_feats, high_feats = self._transform_inputs(inputs)\n        output = self.fusion(low_feats, high_feats)\n        output = self.dropout(output)\n        output = self.bottleneck(output)\n        output = self.context(output)\n        output = self.cls_seg(output)\n\n        return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/apc_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\n\nclass ACM(nn.Module):\n    \"\"\"Adaptive Context Module used in APCNet.\n\n    Args:\n        pool_scale (int): Pooling scale used in Adaptive Context\n            Module to extract region fetures.\n        fusion (bool): Add one conv to fuse residual feature.\n        in_channels (int): Input channels.\n        channels (int): Channels after modules, before conv_seg.\n        conv_cfg (dict | None): Config of conv layers.\n        norm_cfg (dict | None): Config of norm layers.\n        act_cfg (dict): Config of activation layers.\n    \"\"\"\n\n    def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,\n                 norm_cfg, act_cfg):\n        super(ACM, self).__init__()\n        self.pool_scale = pool_scale\n        self.fusion = fusion\n        self.in_channels = in_channels\n        self.channels = channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.pooled_redu_conv = ConvModule(\n            self.in_channels,\n            self.channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n        self.input_redu_conv = ConvModule(\n            self.in_channels,\n            self.channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n        self.global_info = ConvModule(\n            self.channels,\n            self.channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)\n\n        self.residual_conv = ConvModule(\n            self.channels,\n            self.channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n        if self.fusion:\n            self.fusion_conv = ConvModule(\n                self.channels,\n                self.channels,\n                1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)\n        # [batch_size, channels, h, w]\n        x = self.input_redu_conv(x)\n        # [batch_size, channels, pool_scale, pool_scale]\n        pooled_x = self.pooled_redu_conv(pooled_x)\n        batch_size = x.size(0)\n        # [batch_size, pool_scale * pool_scale, channels]\n        pooled_x = pooled_x.view(batch_size, self.channels,\n                                 -1).permute(0, 2, 1).contiguous()\n        # [batch_size, h * w, pool_scale * pool_scale]\n        affinity_matrix = self.gla(x + resize(\n            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])\n                                   ).permute(0, 2, 3, 1).reshape(\n                                       batch_size, -1, self.pool_scale**2)\n        affinity_matrix = F.sigmoid(affinity_matrix)\n        # [batch_size, h * w, channels]\n        z_out = torch.matmul(affinity_matrix, pooled_x)\n        # [batch_size, channels, h * w]\n        z_out = z_out.permute(0, 2, 1).contiguous()\n        # 
[batch_size, channels, h, w]\n        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))\n        z_out = self.residual_conv(z_out)\n        z_out = F.relu(z_out + x)\n        if self.fusion:\n            z_out = self.fusion_conv(z_out)\n\n        return z_out\n\n\n@HEADS.register_module()\nclass APCHead(BaseDecodeHead):\n    \"\"\"Adaptive Pyramid Context Network for Semantic Segmentation.\n\n    This head is the implementation of\n    `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\\\n    He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\\\n    CVPR_2019_paper.pdf>`_.\n\n    Args:\n        pool_scales (tuple[int]): Pooling scales used in Adaptive Context\n            Module. Default: (1, 2, 3, 6).\n        fusion (bool): Add one conv to fuse residual feature.\n    \"\"\"\n\n    def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):\n        super(APCHead, self).__init__(**kwargs)\n        assert isinstance(pool_scales, (list, tuple))\n        self.pool_scales = pool_scales\n        self.fusion = fusion\n        acm_modules = []\n        for pool_scale in self.pool_scales:\n            acm_modules.append(\n                ACM(pool_scale,\n                    self.fusion,\n                    self.in_channels,\n                    self.channels,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg))\n        self.acm_modules = nn.ModuleList(acm_modules)\n        self.bottleneck = ConvModule(\n            self.in_channels + len(pool_scales) * self.channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        acm_outs = [x]\n        for acm_module in self.acm_modules:\n            acm_outs.append(acm_module(x))\n        acm_outs = torch.cat(acm_outs, dim=1)\n        output = self.bottleneck(acm_outs)\n        output = self.cls_seg(output)\n        return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/aspp_head.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\n\nclass ASPPModule(nn.ModuleList):\n    \"\"\"Atrous Spatial Pyramid Pooling (ASPP) Module.\n\n    Args:\n        dilations (tuple[int]): Dilation rate of each layer.\n        in_channels (int): Input channels.\n        channels (int): Channels after modules, before conv_seg.\n        conv_cfg (dict|None): Config of conv layers.\n        norm_cfg (dict|None): Config of norm layers.\n        act_cfg (dict): Config of activation layers.\n    \"\"\"\n\n    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,\n                 act_cfg):\n        super(ASPPModule, self).__init__()\n        self.dilations = dilations\n        self.in_channels = in_channels\n        self.channels = channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        for dilation in dilations:\n            self.append(\n                ConvModule(\n                    self.in_channels,\n                    self.channels,\n                    1 if dilation == 1 else 3,\n                    dilation=dilation,\n                    padding=0 if dilation == 1 else dilation,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg))\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        aspp_outs = []\n        for aspp_module in self:\n            aspp_outs.append(aspp_module(x))\n\n        return aspp_outs\n\n\n@HEADS.register_module()\nclass ASPPHead(BaseDecodeHead):\n    \"\"\"Rethinking Atrous Convolution for Semantic Image Segmentation.\n\n    This head is the implementation of `DeepLabV3\n    <https://arxiv.org/abs/1706.05587>`_.\n\n    Args:\n        dilations (tuple[int]): Dilation rates for ASPP module.\n            Default: (1, 6, 12, 18).\n    \"\"\"\n\n    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):\n        super(ASPPHead, self).__init__(**kwargs)\n        assert isinstance(dilations, (list, tuple))\n        self.dilations = dilations\n        self.image_pool = nn.Sequential(\n            nn.AdaptiveAvgPool2d(1),\n            ConvModule(\n                self.in_channels,\n                self.channels,\n                1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg))\n        self.aspp_modules = ASPPModule(\n            dilations,\n            self.in_channels,\n            self.channels,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.bottleneck = ConvModule(\n            (len(dilations) + 1) * self.channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        aspp_outs = [\n            resize(\n                self.image_pool(x),\n                size=x.size()[2:],\n                mode='bilinear',\n                align_corners=self.align_corners)\n        ]\n        aspp_outs.extend(self.aspp_modules(x))\n        aspp_outs = torch.cat(aspp_outs, dim=1)\n        output = self.bottleneck(aspp_outs)\n        output = self.cls_seg(output)\n        
return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/cascade_decode_head.py",
    "content": "from abc import ABCMeta, abstractmethod\n\nfrom .decode_head import BaseDecodeHead\n\n\nclass BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):\n    \"\"\"Base class for cascade decode head used in\n    :class:`CascadeEncoderDecoder.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)\n\n    @abstractmethod\n    def forward(self, inputs, prev_output):\n        \"\"\"Placeholder of forward function.\"\"\"\n        pass\n\n    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,\n                      train_cfg):\n        \"\"\"Forward function for training.\n        Args:\n            inputs (list[Tensor]): List of multi-level img features.\n            prev_output (Tensor): The output of previous decode head.\n            img_metas (list[dict]): List of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmseg/datasets/pipelines/formatting.py:Collect`.\n            gt_semantic_seg (Tensor): Semantic segmentation masks\n                used if the architecture supports semantic segmentation task.\n            train_cfg (dict): The training config.\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        seg_logits = self.forward(inputs, prev_output)\n        losses = self.losses(seg_logits, gt_semantic_seg)\n\n        return losses\n\n    def forward_test(self, inputs, prev_output, img_metas, test_cfg):\n        \"\"\"Forward function for testing.\n\n        Args:\n            inputs (list[Tensor]): List of multi-level img features.\n            prev_output (Tensor): The output of previous decode head.\n            img_metas (list[dict]): List of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmseg/datasets/pipelines/formatting.py:Collect`.\n            test_cfg (dict): The testing config.\n\n        Returns:\n            Tensor: Output segmentation map.\n        \"\"\"\n        return self.forward(inputs, prev_output)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/cc_head.py",
    "content": "import torch\n\nfrom ..builder import HEADS\nfrom .fcn_head import FCNHead\n\ntry:\n    from mmcv.ops import CrissCrossAttention\nexcept ModuleNotFoundError:\n    CrissCrossAttention = None\n\n\n@HEADS.register_module()\nclass CCHead(FCNHead):\n    \"\"\"CCNet: Criss-Cross Attention for Semantic Segmentation.\n\n    This head is the implementation of `CCNet\n    <https://arxiv.org/abs/1811.11721>`_.\n\n    Args:\n        recurrence (int): Number of recurrence of Criss Cross Attention\n            module. Default: 2.\n    \"\"\"\n\n    def __init__(self, recurrence=2, **kwargs):\n        if CrissCrossAttention is None:\n            raise RuntimeError('Please install mmcv-full for '\n                               'CrissCrossAttention ops')\n        super(CCHead, self).__init__(num_convs=2, **kwargs)\n        self.recurrence = recurrence\n        self.cca = CrissCrossAttention(self.channels)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        output = self.convs[0](x)\n        for _ in range(self.recurrence):\n            output = self.cca(output)\n        output = self.convs[1](output)\n        if self.concat_input:\n            output = self.conv_cat(torch.cat([x, output], dim=1))\n        output = self.cls_seg(output)\n        return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/da_head.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule, Scale\nfrom torch import nn\n\nfrom mmseg.core import add_prefix\nfrom ..builder import HEADS\nfrom ..utils import SelfAttentionBlock as _SelfAttentionBlock\nfrom .decode_head import BaseDecodeHead\n\n\nclass PAM(_SelfAttentionBlock):\n    \"\"\"Position Attention Module (PAM)\n\n    Args:\n        in_channels (int): Input channels of key/query feature.\n        channels (int): Output channels of key/query transform.\n    \"\"\"\n\n    def __init__(self, in_channels, channels):\n        super(PAM, self).__init__(\n            key_in_channels=in_channels,\n            query_in_channels=in_channels,\n            channels=channels,\n            out_channels=in_channels,\n            share_key_query=False,\n            query_downsample=None,\n            key_downsample=None,\n            key_query_num_convs=1,\n            key_query_norm=False,\n            value_out_num_convs=1,\n            value_out_norm=False,\n            matmul_norm=False,\n            with_out=False,\n            conv_cfg=None,\n            norm_cfg=None,\n            act_cfg=None)\n\n        self.gamma = Scale(0)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        out = super(PAM, self).forward(x, x)\n\n        out = self.gamma(out) + x\n        return out\n\n\nclass CAM(nn.Module):\n    \"\"\"Channel Attention Module (CAM)\"\"\"\n\n    def __init__(self):\n        super(CAM, self).__init__()\n        self.gamma = Scale(0)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        batch_size, channels, height, width = x.size()\n        proj_query = x.view(batch_size, channels, -1)\n        proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy_new = torch.max(\n            energy, -1, keepdim=True)[0].expand_as(energy) - energy\n        attention = F.softmax(energy_new, dim=-1)\n        proj_value = x.view(batch_size, channels, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(batch_size, channels, height, width)\n\n        out = self.gamma(out) + x\n        return out\n\n\n@HEADS.register_module()\nclass DAHead(BaseDecodeHead):\n    \"\"\"Dual Attention Network for Scene Segmentation.\n\n    This head is the implementation of `DANet\n    <https://arxiv.org/abs/1809.02983>`_.\n\n    Args:\n        pam_channels (int): The channels of Position Attention Module(PAM).\n    \"\"\"\n\n    def __init__(self, pam_channels, **kwargs):\n        super(DAHead, self).__init__(**kwargs)\n        self.pam_channels = pam_channels\n        self.pam_in_conv = ConvModule(\n            self.in_channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.pam = PAM(self.channels, pam_channels)\n        self.pam_out_conv = ConvModule(\n            self.channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.pam_conv_seg = nn.Conv2d(\n            self.channels, self.num_classes, kernel_size=1)\n\n        self.cam_in_conv = ConvModule(\n            self.in_channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            
act_cfg=self.act_cfg)\n        self.cam = CAM()\n        self.cam_out_conv = ConvModule(\n            self.channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.cam_conv_seg = nn.Conv2d(\n            self.channels, self.num_classes, kernel_size=1)\n\n    def pam_cls_seg(self, feat):\n        \"\"\"PAM feature classification.\"\"\"\n        if self.dropout is not None:\n            feat = self.dropout(feat)\n        output = self.pam_conv_seg(feat)\n        return output\n\n    def cam_cls_seg(self, feat):\n        \"\"\"CAM feature classification.\"\"\"\n        if self.dropout is not None:\n            feat = self.dropout(feat)\n        output = self.cam_conv_seg(feat)\n        return output\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        pam_feat = self.pam_in_conv(x)\n        pam_feat = self.pam(pam_feat)\n        pam_feat = self.pam_out_conv(pam_feat)\n        pam_out = self.pam_cls_seg(pam_feat)\n\n        cam_feat = self.cam_in_conv(x)\n        cam_feat = self.cam(cam_feat)\n        cam_feat = self.cam_out_conv(cam_feat)\n        cam_out = self.cam_cls_seg(cam_feat)\n\n        feat_sum = pam_feat + cam_feat\n        pam_cam_out = self.cls_seg(feat_sum)\n\n        return pam_cam_out, pam_out, cam_out\n\n    def forward_test(self, inputs, img_metas, test_cfg):\n        \"\"\"Forward function for testing, only ``pam_cam`` is used.\"\"\"\n        return self.forward(inputs)[0]\n\n    def losses(self, seg_logit, seg_label):\n        \"\"\"Compute ``pam_cam``, ``pam``, ``cam`` loss.\"\"\"\n        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit\n        loss = dict()\n        loss.update(\n            add_prefix(\n                super(DAHead, self).losses(pam_cam_seg_logit, seg_label),\n                'pam_cam'))\n        loss.update(\n            add_prefix(\n                super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))\n        loss.update(\n            add_prefix(\n                super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))\n        return loss\n"
  },
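  {
    "path": "downstream_tasks/semantic_segmentation/examples/cam_channel_attention_sketch.py",
    "content": "# Hypothetical illustration (path and file added for clarity; not part of the\n# original repository). It replays the channel-attention arithmetic of CAM in\n# da_head.py on a random tensor, assuming only torch is installed.\nimport torch\nimport torch.nn.functional as F\n\nx = torch.randn(2, 8, 4, 4)  # [N, C, H, W]\nb, c, h, w = x.shape\nproj_query = x.view(b, c, -1)                 # [N, C, H*W]\nproj_key = x.view(b, c, -1).permute(0, 2, 1)  # [N, H*W, C]\nenergy = torch.bmm(proj_query, proj_key)      # [N, C, C] channel affinities\n# Subtracting each row from its max inverts the scores, so softmax puts less\n# weight on the most similar (redundant) channels.\nenergy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy\nattention = F.softmax(energy_new, dim=-1)\nproj_value = x.view(b, c, -1)\nout = torch.bmm(attention, proj_value).view(b, c, h, w)\nprint(out.shape)  # torch.Size([2, 8, 4, 4])\n"
  },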
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/decode_head.py",
    "content": "from abc import ABCMeta, abstractmethod\n\nimport torch\nimport torch.nn as nn\nfrom mmcv.cnn import normal_init\nfrom mmcv.runner import auto_fp16, force_fp32\n\nfrom mmseg.core import build_pixel_sampler\nfrom mmseg.ops import resize\nfrom ..builder import build_loss\nfrom ..losses import accuracy\n\n\nclass BaseDecodeHead(nn.Module, metaclass=ABCMeta):\n    \"\"\"Base class for BaseDecodeHead.\n\n    Args:\n        in_channels (int|Sequence[int]): Input channels.\n        channels (int): Channels after modules, before conv_seg.\n        num_classes (int): Number of classes.\n        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.\n        conv_cfg (dict|None): Config of conv layers. Default: None.\n        norm_cfg (dict|None): Config of norm layers. Default: None.\n        act_cfg (dict): Config of activation layers.\n            Default: dict(type='ReLU')\n        in_index (int|Sequence[int]): Input feature index. Default: -1\n        input_transform (str|None): Transformation type of input features.\n            Options: 'resize_concat', 'multiple_select', None.\n            'resize_concat': Multiple feature maps will be resize to the\n                same size as first one and than concat together.\n                Usually used in FCN head of HRNet.\n            'multiple_select': Multiple feature maps will be bundle into\n                a list and passed into decode head.\n            None: Only one select feature map is allowed.\n            Default: None.\n        loss_decode (dict): Config of decode loss.\n            Default: dict(type='CrossEntropyLoss').\n        ignore_index (int | None): The label index to be ignored. When using\n            masked BCE loss, ignore_index should be set to None. Default: 255\n        sampler (dict|None): The config of segmentation map sampler.\n            Default: None.\n        align_corners (bool): align_corners argument of F.interpolate.\n            Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 channels,\n                 *,\n                 num_classes,\n                 dropout_ratio=0.1,\n                 conv_cfg=None,\n                 norm_cfg=None,\n                 act_cfg=dict(type='ReLU'),\n                 in_index=-1,\n                 input_transform=None,\n                 loss_decode=dict(\n                     type='CrossEntropyLoss',\n                     use_sigmoid=False,\n                     loss_weight=1.0),\n                 ignore_index=255,\n                 sampler=None,\n                 align_corners=False):\n        super(BaseDecodeHead, self).__init__()\n        self._init_inputs(in_channels, in_index, input_transform)\n        self.channels = channels\n        self.num_classes = num_classes\n        self.dropout_ratio = dropout_ratio\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.in_index = in_index\n        self.loss_decode = build_loss(loss_decode)\n        self.ignore_index = ignore_index\n        self.align_corners = align_corners\n        if sampler is not None:\n            self.sampler = build_pixel_sampler(sampler, context=self)\n        else:\n            self.sampler = None\n\n        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)\n        if dropout_ratio > 0:\n            self.dropout = nn.Dropout2d(dropout_ratio)\n        else:\n            self.dropout = None\n        self.fp16_enabled = False\n\n    def extra_repr(self):\n  
      \"\"\"Extra repr.\"\"\"\n        s = f'input_transform={self.input_transform}, ' \\\n            f'ignore_index={self.ignore_index}, ' \\\n            f'align_corners={self.align_corners}'\n        return s\n\n    def _init_inputs(self, in_channels, in_index, input_transform):\n        \"\"\"Check and initialize input transforms.\n\n        The in_channels, in_index and input_transform must match.\n        Specifically, when input_transform is None, only single feature map\n        will be selected. So in_channels and in_index must be of type int.\n        When input_transform\n\n        Args:\n            in_channels (int|Sequence[int]): Input channels.\n            in_index (int|Sequence[int]): Input feature index.\n            input_transform (str|None): Transformation type of input features.\n                Options: 'resize_concat', 'multiple_select', None.\n                'resize_concat': Multiple feature maps will be resize to the\n                    same size as first one and than concat together.\n                    Usually used in FCN head of HRNet.\n                'multiple_select': Multiple feature maps will be bundle into\n                    a list and passed into decode head.\n                None: Only one select feature map is allowed.\n        \"\"\"\n\n        if input_transform is not None:\n            assert input_transform in ['resize_concat', 'multiple_select']\n        self.input_transform = input_transform\n        self.in_index = in_index\n        if input_transform is not None:\n            assert isinstance(in_channels, (list, tuple))\n            assert isinstance(in_index, (list, tuple))\n            assert len(in_channels) == len(in_index)\n            if input_transform == 'resize_concat':\n                self.in_channels = sum(in_channels)\n            else:\n                self.in_channels = in_channels\n        else:\n            assert isinstance(in_channels, int)\n            assert isinstance(in_index, int)\n            self.in_channels = in_channels\n\n    def init_weights(self):\n        \"\"\"Initialize weights of classification layer.\"\"\"\n        normal_init(self.conv_seg, mean=0, std=0.01)\n\n    def _transform_inputs(self, inputs):\n        \"\"\"Transform inputs for decoder.\n\n        Args:\n            inputs (list[Tensor]): List of multi-level img features.\n\n        Returns:\n            Tensor: The transformed inputs\n        \"\"\"\n\n        if self.input_transform == 'resize_concat':\n            inputs = [inputs[i] for i in self.in_index]\n            upsampled_inputs = [\n                resize(\n                    input=x,\n                    size=inputs[0].shape[2:],\n                    mode='bilinear',\n                    align_corners=self.align_corners) for x in inputs\n            ]\n            inputs = torch.cat(upsampled_inputs, dim=1)\n        elif self.input_transform == 'multiple_select':\n            inputs = [inputs[i] for i in self.in_index]\n        else:\n            inputs = inputs[self.in_index]\n\n        return inputs\n\n    @auto_fp16()\n    @abstractmethod\n    def forward(self, inputs):\n        \"\"\"Placeholder of forward function.\"\"\"\n        pass\n\n    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):\n        \"\"\"Forward function for training.\n        Args:\n            inputs (list[Tensor]): List of multi-level img features.\n            img_metas (list[dict]): List of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 
'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmseg/datasets/pipelines/formatting.py:Collect`.\n            gt_semantic_seg (Tensor): Semantic segmentation masks\n                used if the architecture supports semantic segmentation task.\n            train_cfg (dict): The training config.\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        seg_logits = self.forward(inputs)\n        losses = self.losses(seg_logits, gt_semantic_seg)\n        return losses\n\n    def forward_test(self, inputs, img_metas, test_cfg):\n        \"\"\"Forward function for testing.\n\n        Args:\n            inputs (list[Tensor]): List of multi-level img features.\n            img_metas (list[dict]): List of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmseg/datasets/pipelines/formatting.py:Collect`.\n            test_cfg (dict): The testing config.\n\n        Returns:\n            Tensor: Output segmentation map.\n        \"\"\"\n        return self.forward(inputs)\n\n    def cls_seg(self, feat):\n        \"\"\"Classify each pixel.\"\"\"\n        if self.dropout is not None:\n            feat = self.dropout(feat)\n        output = self.conv_seg(feat)\n        return output\n\n    @force_fp32(apply_to=('seg_logit', ))\n    def losses(self, seg_logit, seg_label):\n        \"\"\"Compute segmentation loss.\"\"\"\n        loss = dict()\n        seg_logit = resize(\n            input=seg_logit,\n            size=seg_label.shape[2:],\n            mode='bilinear',\n            align_corners=self.align_corners)\n        if self.sampler is not None:\n            seg_weight = self.sampler.sample(seg_logit, seg_label)\n        else:\n            seg_weight = None\n        seg_label = seg_label.squeeze(1)\n        loss['loss_seg'] = self.loss_decode(\n            seg_logit,\n            seg_label,\n            weight=seg_weight,\n            ignore_index=self.ignore_index)\n        loss['acc_seg'] = accuracy(seg_logit, seg_label)\n        return loss\n"
  },
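  {
    "path": "downstream_tasks/semantic_segmentation/examples/decode_head_input_transform_sketch.py",
    "content": "# Hypothetical illustration (file added for clarity; not part of the original\n# repository). It sketches the three input_transform modes handled by\n# BaseDecodeHead._transform_inputs, with torch.nn.functional.interpolate standing\n# in for mmseg.ops.resize. Feature shapes are made up.\nimport torch\nimport torch.nn.functional as F\n\nfeats = [torch.randn(1, 16, 32, 32),\n         torch.randn(1, 32, 16, 16),\n         torch.randn(1, 64, 8, 8)]\n\n# input_transform=None with in_index=-1: a single feature map is selected.\nsingle = feats[-1]\n\n# input_transform='multiple_select' with in_index=[0, 1, 2]: keep them as a list.\nmultiple = [feats[i] for i in [0, 1, 2]]\n\n# input_transform='resize_concat' with in_index=[0, 1, 2]: resize every map to the\n# first one's spatial size, then concatenate channels (16 + 32 + 64 = 112).\nselected = [feats[i] for i in [0, 1, 2]]\nupsampled = [F.interpolate(x, size=selected[0].shape[2:], mode='bilinear',\n                           align_corners=False) for x in selected]\nconcat = torch.cat(upsampled, dim=1)\nprint(single.shape, len(multiple), concat.shape)  # last: torch.Size([1, 112, 32, 32])\n"
  },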
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/dm_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer\n\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\n\nclass DCM(nn.Module):\n    \"\"\"Dynamic Convolutional Module used in DMNet.\n\n    Args:\n        filter_size (int): The filter size of generated convolution kernel\n            used in Dynamic Convolutional Module.\n        fusion (bool): Add one conv to fuse DCM output feature.\n        in_channels (int): Input channels.\n        channels (int): Channels after modules, before conv_seg.\n        conv_cfg (dict | None): Config of conv layers.\n        norm_cfg (dict | None): Config of norm layers.\n        act_cfg (dict): Config of activation layers.\n    \"\"\"\n\n    def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,\n                 norm_cfg, act_cfg):\n        super(DCM, self).__init__()\n        self.filter_size = filter_size\n        self.fusion = fusion\n        self.in_channels = in_channels\n        self.channels = channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,\n                                         0)\n\n        self.input_redu_conv = ConvModule(\n            self.in_channels,\n            self.channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n        if self.norm_cfg is not None:\n            self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]\n        else:\n            self.norm = None\n        self.activate = build_activation_layer(self.act_cfg)\n\n        if self.fusion:\n            self.fusion_conv = ConvModule(\n                self.channels,\n                self.channels,\n                1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        generted_filter = self.filter_gen_conv(\n            F.adaptive_avg_pool2d(x, self.filter_size))\n        x = self.input_redu_conv(x)\n        b, c, h, w = x.shape\n        # [1, b * c, h, w], c = self.channels\n        x = x.view(1, b * c, h, w)\n        # [b * c, 1, filter_size, filter_size]\n        generted_filter = generted_filter.view(b * c, 1, self.filter_size,\n                                               self.filter_size)\n        pad = (self.filter_size - 1) // 2\n        if (self.filter_size - 1) % 2 == 0:\n            p2d = (pad, pad, pad, pad)\n        else:\n            p2d = (pad + 1, pad, pad + 1, pad)\n        x = F.pad(input=x, pad=p2d, mode='constant', value=0)\n        # [1, b * c, h, w]\n        output = F.conv2d(input=x, weight=generted_filter, groups=b * c)\n        # [b, c, h, w]\n        output = output.view(b, c, h, w)\n        if self.norm is not None:\n            output = self.norm(output)\n        output = self.activate(output)\n\n        if self.fusion:\n            output = self.fusion_conv(output)\n\n        return output\n\n\n@HEADS.register_module()\nclass DMHead(BaseDecodeHead):\n    \"\"\"Dynamic Multi-scale Filters for Semantic Segmentation.\n\n    This head is the implementation of\n    `DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\\\n        He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\\\n            
ICCV_2019_paper.pdf>`_.\n\n    Args:\n        filter_sizes (tuple[int]): The size of generated convolutional filters\n            used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).\n        fusion (bool): Add one conv to fuse DCM output feature.\n    \"\"\"\n\n    def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):\n        super(DMHead, self).__init__(**kwargs)\n        assert isinstance(filter_sizes, (list, tuple))\n        self.filter_sizes = filter_sizes\n        self.fusion = fusion\n        dcm_modules = []\n        for filter_size in self.filter_sizes:\n            dcm_modules.append(\n                DCM(filter_size,\n                    self.fusion,\n                    self.in_channels,\n                    self.channels,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg))\n        self.dcm_modules = nn.ModuleList(dcm_modules)\n        self.bottleneck = ConvModule(\n            self.in_channels + len(filter_sizes) * self.channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        dcm_outs = [x]\n        for dcm_module in self.dcm_modules:\n            dcm_outs.append(dcm_module(x))\n        dcm_outs = torch.cat(dcm_outs, dim=1)\n        output = self.bottleneck(dcm_outs)\n        output = self.cls_seg(output)\n        return output\n"
  },
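  {
    "path": "downstream_tasks/semantic_segmentation/examples/dcm_dynamic_filtering_sketch.py",
    "content": "# Hypothetical illustration (file added for clarity; not part of the original\n# repository). It isolates the dynamic-filtering trick used by DCM in dm_head.py:\n# folding the batch into the channel axis and convolving with groups=b*c applies a\n# different generated filter to every (sample, channel) pair. Tensors are random.\nimport torch\nimport torch.nn.functional as F\n\nb, c, h, w, filter_size = 2, 4, 8, 8, 3\nx = torch.randn(b, c, h, w)\nfilters = torch.randn(b * c, 1, filter_size, filter_size)  # one filter per (sample, channel)\n\nx = x.view(1, b * c, h, w)\npad = (filter_size - 1) // 2  # filter_size is odd here, so the padding is symmetric\nx = F.pad(x, (pad, pad, pad, pad))\nout = F.conv2d(x, weight=filters, groups=b * c).view(b, c, h, w)\nprint(out.shape)  # torch.Size([2, 4, 8, 8])\n"
  },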
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/dnl_head.py",
    "content": "import torch\nfrom mmcv.cnn import NonLocal2d\nfrom torch import nn\n\nfrom ..builder import HEADS\nfrom .fcn_head import FCNHead\n\n\nclass DisentangledNonLocal2d(NonLocal2d):\n    \"\"\"Disentangled Non-Local Blocks.\n\n    Args:\n        temperature (float): Temperature to adjust attention. Default: 0.05\n    \"\"\"\n\n    def __init__(self, *arg, temperature, **kwargs):\n        super().__init__(*arg, **kwargs)\n        self.temperature = temperature\n        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)\n\n    def embedded_gaussian(self, theta_x, phi_x):\n        \"\"\"Embedded gaussian with temperature.\"\"\"\n\n        # NonLocal2d pairwise_weight: [N, HxW, HxW]\n        pairwise_weight = torch.matmul(theta_x, phi_x)\n        if self.use_scale:\n            # theta_x.shape[-1] is `self.inter_channels`\n            pairwise_weight /= theta_x.shape[-1]**0.5\n        pairwise_weight /= self.temperature\n        pairwise_weight = pairwise_weight.softmax(dim=-1)\n        return pairwise_weight\n\n    def forward(self, x):\n        # x: [N, C, H, W]\n        n = x.size(0)\n\n        # g_x: [N, HxW, C]\n        g_x = self.g(x).view(n, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]\n        if self.mode == 'gaussian':\n            theta_x = x.view(n, self.in_channels, -1)\n            theta_x = theta_x.permute(0, 2, 1)\n            if self.sub_sample:\n                phi_x = self.phi(x).view(n, self.in_channels, -1)\n            else:\n                phi_x = x.view(n, self.in_channels, -1)\n        elif self.mode == 'concatenation':\n            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)\n            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)\n        else:\n            theta_x = self.theta(x).view(n, self.inter_channels, -1)\n            theta_x = theta_x.permute(0, 2, 1)\n            phi_x = self.phi(x).view(n, self.inter_channels, -1)\n\n        # subtract mean\n        theta_x -= theta_x.mean(dim=-2, keepdim=True)\n        phi_x -= phi_x.mean(dim=-1, keepdim=True)\n\n        pairwise_func = getattr(self, self.mode)\n        # pairwise_weight: [N, HxW, HxW]\n        pairwise_weight = pairwise_func(theta_x, phi_x)\n\n        # y: [N, HxW, C]\n        y = torch.matmul(pairwise_weight, g_x)\n        # y: [N, C, H, W]\n        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,\n                                                    *x.size()[2:])\n\n        # unary_mask: [N, 1, HxW]\n        unary_mask = self.conv_mask(x)\n        unary_mask = unary_mask.view(n, 1, -1)\n        unary_mask = unary_mask.softmax(dim=-1)\n        # unary_x: [N, 1, C]\n        unary_x = torch.matmul(unary_mask, g_x)\n        # unary_x: [N, C, 1, 1]\n        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(\n            n, self.inter_channels, 1, 1)\n\n        output = x + self.conv_out(y + unary_x)\n\n        return output\n\n\n@HEADS.register_module()\nclass DNLHead(FCNHead):\n    \"\"\"Disentangled Non-Local Neural Networks.\n\n    This head is the implementation of `DNLNet\n    <https://arxiv.org/abs/2006.06668>`_.\n\n    Args:\n        reduction (int): Reduction factor of projection transform. Default: 2.\n        use_scale (bool): Whether to scale pairwise_weight by\n            sqrt(1/inter_channels). Default: False.\n        mode (str): The nonlocal mode. Options are 'embedded_gaussian',\n            'dot_product'. 
Default: 'embedded_gaussian'.\n        temperature (float): Temperature to adjust attention. Default: 0.05\n    \"\"\"\n\n    def __init__(self,\n                 reduction=2,\n                 use_scale=True,\n                 mode='embedded_gaussian',\n                 temperature=0.05,\n                 **kwargs):\n        super(DNLHead, self).__init__(num_convs=2, **kwargs)\n        self.reduction = reduction\n        self.use_scale = use_scale\n        self.mode = mode\n        self.temperature = temperature\n        self.dnl_block = DisentangledNonLocal2d(\n            in_channels=self.channels,\n            reduction=self.reduction,\n            use_scale=self.use_scale,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            mode=self.mode,\n            temperature=self.temperature)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        output = self.convs[0](x)\n        output = self.dnl_block(output)\n        output = self.convs[1](output)\n        if self.concat_input:\n            output = self.conv_cat(torch.cat([x, output], dim=1))\n        output = self.cls_seg(output)\n        return output\n"
  },
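  {
    "path": "downstream_tasks/semantic_segmentation/examples/dnl_whitened_attention_sketch.py",
    "content": "# Hypothetical illustration (file added for clarity; not part of the original\n# repository). It sketches, on random tensors, the two changes that\n# DisentangledNonLocal2d makes to a plain non-local block: whitening query and key\n# by subtracting their means, and adding a unary branch whose pooled context vector\n# is shared by every position.\nimport torch\n\nn, c, hw = 2, 16, 64\ntheta_x = torch.randn(n, hw, c)   # queries, [N, HxW, C]\nphi_x = torch.randn(n, c, hw)     # keys,    [N, C, HxW]\ng_x = torch.randn(n, hw, c)       # values,  [N, HxW, C]\ntemperature = 0.05\n\n# Whitening: remove the mean query and the mean key before the dot product.\ntheta_x = theta_x - theta_x.mean(dim=-2, keepdim=True)\nphi_x = phi_x - phi_x.mean(dim=-1, keepdim=True)\n\npairwise = torch.matmul(theta_x, phi_x) / (c ** 0.5)   # use_scale=True\npairwise = (pairwise / temperature).softmax(dim=-1)    # [N, HxW, HxW]\npairwise_out = torch.matmul(pairwise, g_x)             # [N, HxW, C]\n\n# Unary branch: a single spatial mask (conv_mask(x) in the module) pools one\n# context vector per image, broadcast to all positions.\nunary_mask = torch.randn(n, 1, hw).softmax(dim=-1)\nunary_out = torch.matmul(unary_mask, g_x)              # [N, 1, C]\n\nout = pairwise_out + unary_out\nprint(out.shape)  # torch.Size([2, 64, 16])\n"
  },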
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ema_head.py",
    "content": "import math\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule\n\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\n\ndef reduce_mean(tensor):\n    \"\"\"Reduce mean when distributed training.\"\"\"\n    if not (dist.is_available() and dist.is_initialized()):\n        return tensor\n    tensor = tensor.clone()\n    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)\n    return tensor\n\n\nclass EMAModule(nn.Module):\n    \"\"\"Expectation Maximization Attention Module used in EMANet.\n\n    Args:\n        channels (int): Channels of the whole module.\n        num_bases (int): Number of bases.\n        num_stages (int): Number of the EM iterations.\n    \"\"\"\n\n    def __init__(self, channels, num_bases, num_stages, momentum):\n        super(EMAModule, self).__init__()\n        assert num_stages >= 1, 'num_stages must be at least 1!'\n        self.num_bases = num_bases\n        self.num_stages = num_stages\n        self.momentum = momentum\n\n        bases = torch.zeros(1, channels, self.num_bases)\n        bases.normal_(0, math.sqrt(2. / self.num_bases))\n        # [1, channels, num_bases]\n        bases = F.normalize(bases, dim=1, p=2)\n        self.register_buffer('bases', bases)\n\n    def forward(self, feats):\n        \"\"\"Forward function.\"\"\"\n        batch_size, channels, height, width = feats.size()\n        # [batch_size, channels, height*width]\n        feats = feats.view(batch_size, channels, height * width)\n        # [batch_size, channels, num_bases]\n        bases = self.bases.repeat(batch_size, 1, 1)\n\n        with torch.no_grad():\n            for i in range(self.num_stages):\n                # [batch_size, height*width, num_bases]\n                attention = torch.einsum('bcn,bck->bnk', feats, bases)\n                attention = F.softmax(attention, dim=2)\n                # l1 norm\n                attention_normed = F.normalize(attention, dim=1, p=1)\n                # [batch_size, channels, num_bases]\n                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)\n                # l2 norm\n                bases = F.normalize(bases, dim=1, p=2)\n\n        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)\n        feats_recon = feats_recon.view(batch_size, channels, height, width)\n\n        if self.training:\n            bases = bases.mean(dim=0, keepdim=True)\n            bases = reduce_mean(bases)\n            # l2 norm\n            bases = F.normalize(bases, dim=1, p=2)\n            self.bases = (1 -\n                          self.momentum) * self.bases + self.momentum * bases\n\n        return feats_recon\n\n\n@HEADS.register_module()\nclass EMAHead(BaseDecodeHead):\n    \"\"\"Expectation Maximization Attention Networks for Semantic Segmentation.\n\n    This head is the implementation of `EMANet\n    <https://arxiv.org/abs/1907.13426>`_.\n\n    Args:\n        ema_channels (int): EMA module channels\n        num_bases (int): Number of bases.\n        num_stages (int): Number of the EM iterations.\n        concat_input (bool): Whether concat the input and output of convs\n            before classification layer. Default: True\n        momentum (float): Momentum to update the base. 
Default: 0.1.\n    \"\"\"\n\n    def __init__(self,\n                 ema_channels,\n                 num_bases,\n                 num_stages,\n                 concat_input=True,\n                 momentum=0.1,\n                 **kwargs):\n        super(EMAHead, self).__init__(**kwargs)\n        self.ema_channels = ema_channels\n        self.num_bases = num_bases\n        self.num_stages = num_stages\n        self.concat_input = concat_input\n        self.momentum = momentum\n        self.ema_module = EMAModule(self.ema_channels, self.num_bases,\n                                    self.num_stages, self.momentum)\n\n        self.ema_in_conv = ConvModule(\n            self.in_channels,\n            self.ema_channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        # project (0, inf) -> (-inf, inf)\n        self.ema_mid_conv = ConvModule(\n            self.ema_channels,\n            self.ema_channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=None,\n            act_cfg=None)\n        for param in self.ema_mid_conv.parameters():\n            param.requires_grad = False\n\n        self.ema_out_conv = ConvModule(\n            self.ema_channels,\n            self.ema_channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=None)\n        self.bottleneck = ConvModule(\n            self.ema_channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        if self.concat_input:\n            self.conv_cat = ConvModule(\n                self.in_channels + self.channels,\n                self.channels,\n                kernel_size=3,\n                padding=1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        feats = self.ema_in_conv(x)\n        identity = feats\n        feats = self.ema_mid_conv(feats)\n        recon = self.ema_module(feats)\n        recon = F.relu(recon, inplace=True)\n        recon = self.ema_out_conv(recon)\n        output = F.relu(identity + recon, inplace=True)\n        output = self.bottleneck(output)\n        if self.concat_input:\n            output = self.conv_cat(torch.cat([x, output], dim=1))\n        output = self.cls_seg(output)\n        return output\n"
  },
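  {
    "path": "downstream_tasks/semantic_segmentation/examples/ema_iteration_sketch.py",
    "content": "# Hypothetical illustration (file added for clarity; not part of the original\n# repository). It runs one expectation-maximization round from EMAModule in\n# ema_head.py on random features. In the einsum strings, b/c/n/k stand for batch,\n# channels, pixels and bases.\nimport torch\nimport torch.nn.functional as F\n\nb, c, n, k = 2, 8, 64, 4\nfeats = torch.randn(b, c, n)\nbases = F.normalize(torch.randn(b, c, k), dim=1, p=2)\n\n# E-step: soft-assign each pixel to the bases.\nattention = torch.einsum('bcn,bck->bnk', feats, bases).softmax(dim=2)\nattention_normed = F.normalize(attention, dim=1, p=1)\n# M-step: re-estimate the bases from the soft assignments, then l2-normalize them.\nbases = F.normalize(torch.einsum('bcn,bnk->bck', feats, attention_normed), dim=1, p=2)\n# Low-rank reconstruction that the head feeds into its output convolutions.\nrecon = torch.einsum('bck,bnk->bcn', bases, attention)\nprint(recon.shape)  # torch.Size([2, 8, 64])\n"
  },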
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/enc_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule, build_norm_layer\n\nfrom mmseg.ops import Encoding, resize\nfrom ..builder import HEADS, build_loss\nfrom .decode_head import BaseDecodeHead\n\n\nclass EncModule(nn.Module):\n    \"\"\"Encoding Module used in EncNet.\n\n    Args:\n        in_channels (int): Input channels.\n        num_codes (int): Number of code words.\n        conv_cfg (dict|None): Config of conv layers.\n        norm_cfg (dict|None): Config of norm layers.\n        act_cfg (dict): Config of activation layers.\n    \"\"\"\n\n    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):\n        super(EncModule, self).__init__()\n        self.encoding_project = ConvModule(\n            in_channels,\n            in_channels,\n            1,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n        # TODO: resolve this hack\n        # change to 1d\n        if norm_cfg is not None:\n            encoding_norm_cfg = norm_cfg.copy()\n            if encoding_norm_cfg['type'] in ['BN', 'IN']:\n                encoding_norm_cfg['type'] += '1d'\n            else:\n                encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(\n                    '2d', '1d')\n        else:\n            # fallback to BN1d\n            encoding_norm_cfg = dict(type='BN1d')\n        self.encoding = nn.Sequential(\n            Encoding(channels=in_channels, num_codes=num_codes),\n            build_norm_layer(encoding_norm_cfg, num_codes)[1],\n            nn.ReLU(inplace=True))\n        self.fc = nn.Sequential(\n            nn.Linear(in_channels, in_channels), nn.Sigmoid())\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        encoding_projection = self.encoding_project(x)\n        encoding_feat = self.encoding(encoding_projection).mean(dim=1)\n        batch_size, channels, _, _ = x.size()\n        gamma = self.fc(encoding_feat)\n        y = gamma.view(batch_size, channels, 1, 1)\n        output = F.relu_(x + x * y)\n        return encoding_feat, output\n\n\n@HEADS.register_module()\nclass EncHead(BaseDecodeHead):\n    \"\"\"Context Encoding for Semantic Segmentation.\n\n    This head is the implementation of `EncNet\n    <https://arxiv.org/abs/1803.08904>`_.\n\n    Args:\n        num_codes (int): Number of code words. Default: 32.\n        use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to\n            regularize the training. 
Default: True.\n        add_lateral (bool): Whether use lateral connection to fuse features.\n            Default: False.\n        loss_se_decode (dict): Config of decode loss.\n            Default: dict(type='CrossEntropyLoss', use_sigmoid=True).\n    \"\"\"\n\n    def __init__(self,\n                 num_codes=32,\n                 use_se_loss=True,\n                 add_lateral=False,\n                 loss_se_decode=dict(\n                     type='CrossEntropyLoss',\n                     use_sigmoid=True,\n                     loss_weight=0.2),\n                 **kwargs):\n        super(EncHead, self).__init__(\n            input_transform='multiple_select', **kwargs)\n        self.use_se_loss = use_se_loss\n        self.add_lateral = add_lateral\n        self.num_codes = num_codes\n        self.bottleneck = ConvModule(\n            self.in_channels[-1],\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        if add_lateral:\n            self.lateral_convs = nn.ModuleList()\n            for in_channels in self.in_channels[:-1]:  # skip the last one\n                self.lateral_convs.append(\n                    ConvModule(\n                        in_channels,\n                        self.channels,\n                        1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg))\n            self.fusion = ConvModule(\n                len(self.in_channels) * self.channels,\n                self.channels,\n                3,\n                padding=1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg)\n        self.enc_module = EncModule(\n            self.channels,\n            num_codes=num_codes,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        if self.use_se_loss:\n            self.loss_se_decode = build_loss(loss_se_decode)\n            self.se_layer = nn.Linear(self.channels, self.num_classes)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        inputs = self._transform_inputs(inputs)\n        feat = self.bottleneck(inputs[-1])\n        if self.add_lateral:\n            laterals = [\n                resize(\n                    lateral_conv(inputs[i]),\n                    size=feat.shape[2:],\n                    mode='bilinear',\n                    align_corners=self.align_corners)\n                for i, lateral_conv in enumerate(self.lateral_convs)\n            ]\n            feat = self.fusion(torch.cat([feat, *laterals], 1))\n        encode_feat, output = self.enc_module(feat)\n        output = self.cls_seg(output)\n        if self.use_se_loss:\n            se_output = self.se_layer(encode_feat)\n            return output, se_output\n        else:\n            return output\n\n    def forward_test(self, inputs, img_metas, test_cfg):\n        \"\"\"Forward function for testing, ignore se_loss.\"\"\"\n        if self.use_se_loss:\n            return self.forward(inputs)[0]\n        else:\n            return self.forward(inputs)\n\n    @staticmethod\n    def _convert_to_onehot_labels(seg_label, num_classes):\n        \"\"\"Convert segmentation label to onehot.\n\n        Args:\n            seg_label (Tensor): Segmentation label of shape (N, H, W).\n            num_classes (int): Number of 
classes.\n\n        Returns:\n            Tensor: Onehot labels of shape (N, num_classes).\n        \"\"\"\n\n        batch_size = seg_label.size(0)\n        onehot_labels = seg_label.new_zeros((batch_size, num_classes))\n        for i in range(batch_size):\n            hist = seg_label[i].float().histc(\n                bins=num_classes, min=0, max=num_classes - 1)\n            onehot_labels[i] = hist > 0\n        return onehot_labels\n\n    def losses(self, seg_logit, seg_label):\n        \"\"\"Compute segmentation and semantic encoding loss.\"\"\"\n        seg_logit, se_seg_logit = seg_logit\n        loss = dict()\n        loss.update(super(EncHead, self).losses(seg_logit, seg_label))\n        se_loss = self.loss_se_decode(\n            se_seg_logit,\n            self._convert_to_onehot_labels(seg_label, self.num_classes))\n        loss['loss_se'] = se_loss\n        return loss\n"
  },
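  {
    "path": "downstream_tasks/semantic_segmentation/examples/enc_head_se_loss_target_sketch.py",
    "content": "# Hypothetical illustration (file added for clarity; not part of the original\n# repository). It shows how EncHead turns a dense (N, H, W) label map into the\n# (N, num_classes) multi-hot target used by the SE-loss, via a per-image histogram.\n# The labels are random.\nimport torch\n\nnum_classes = 5\nseg_label = torch.randint(0, num_classes, (2, 4, 4))\nonehot_labels = seg_label.new_zeros((2, num_classes))\nfor i in range(seg_label.size(0)):\n    hist = seg_label[i].float().histc(bins=num_classes, min=0, max=num_classes - 1)\n    onehot_labels[i] = hist > 0  # 1 if the class appears anywhere in image i\nprint(onehot_labels)\n"
  },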
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/fcn_head.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule\n\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\n\n@HEADS.register_module()\nclass FCNHead(BaseDecodeHead):\n    \"\"\"Fully Convolution Networks for Semantic Segmentation.\n\n    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.\n\n    Args:\n        num_convs (int): Number of convs in the head. Default: 2.\n        kernel_size (int): The kernel size for convs in the head. Default: 3.\n        concat_input (bool): Whether concat the input and output of convs\n            before classification layer.\n    \"\"\"\n\n    def __init__(self,\n                 num_convs=2,\n                 kernel_size=3,\n                 concat_input=True,\n                 **kwargs):\n        assert num_convs >= 0\n        self.num_convs = num_convs\n        self.concat_input = concat_input\n        self.kernel_size = kernel_size\n        super(FCNHead, self).__init__(**kwargs)\n        if num_convs == 0:\n            assert self.in_channels == self.channels\n\n        convs = []\n        convs.append(\n            ConvModule(\n                self.in_channels,\n                self.channels,\n                kernel_size=kernel_size,\n                padding=kernel_size // 2,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg))\n        for i in range(num_convs - 1):\n            convs.append(\n                ConvModule(\n                    self.channels,\n                    self.channels,\n                    kernel_size=kernel_size,\n                    padding=kernel_size // 2,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg))\n        if num_convs == 0:\n            self.convs = nn.Identity()\n        else:\n            self.convs = nn.Sequential(*convs)\n        if self.concat_input:\n            self.conv_cat = ConvModule(\n                self.in_channels + self.channels,\n                self.channels,\n                kernel_size=kernel_size,\n                padding=kernel_size // 2,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        output = self.convs(x)\n        if self.concat_input:\n            output = self.conv_cat(torch.cat([x, output], dim=1))\n        output = self.cls_seg(output)\n        return output\n"
  },
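  {
    "path": "downstream_tasks/semantic_segmentation/examples/fcn_head_usage_sketch.py",
    "content": "# Hypothetical usage sketch (file added for clarity; not part of the original\n# repository). It builds the FCNHead defined above and runs a forward pass,\n# assuming this repo's mmseg package and mmcv are importable. Channel counts,\n# class count and feature size are made up.\nimport torch\n\nfrom mmseg.models.decode_heads import FCNHead\n\nhead = FCNHead(\n    in_channels=64,\n    channels=32,\n    num_classes=19,\n    num_convs=2,\n    concat_input=True,\n    norm_cfg=dict(type='BN'))\nhead.init_weights()\n\n# With the defaults in_index=-1 and input_transform=None, the head simply picks\n# the last feature map from the list it is given.\nfeats = [torch.randn(1, 64, 16, 16)]\nlogits = head(feats)\nprint(logits.shape)  # torch.Size([1, 19, 16, 16])\n"
  },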
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/fpn_head.py",
    "content": "import numpy as np\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\n\n@HEADS.register_module()\nclass FPNHead(BaseDecodeHead):\n    \"\"\"Panoptic Feature Pyramid Networks.\n\n    This head is the implementation of `Semantic FPN\n    <https://arxiv.org/abs/1901.02446>`_.\n\n    Args:\n        feature_strides (tuple[int]): The strides for input feature maps.\n            stack_lateral. All strides suppose to be power of 2. The first\n            one is of largest resolution.\n    \"\"\"\n\n    def __init__(self, feature_strides, **kwargs):\n        super(FPNHead, self).__init__(\n            input_transform='multiple_select', **kwargs)\n        assert len(feature_strides) == len(self.in_channels)\n        assert min(feature_strides) == feature_strides[0]\n        self.feature_strides = feature_strides\n\n        self.scale_heads = nn.ModuleList()\n        for i in range(len(feature_strides)):\n            head_length = max(\n                1,\n                int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))\n            scale_head = []\n            for k in range(head_length):\n                scale_head.append(\n                    ConvModule(\n                        self.in_channels[i] if k == 0 else self.channels,\n                        self.channels,\n                        3,\n                        padding=1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg))\n                if feature_strides[i] != feature_strides[0]:\n                    scale_head.append(\n                        nn.Upsample(\n                            scale_factor=2,\n                            mode='bilinear',\n                            align_corners=self.align_corners))\n            self.scale_heads.append(nn.Sequential(*scale_head))\n\n    def forward(self, inputs):\n\n        x = self._transform_inputs(inputs)\n\n        output = self.scale_heads[0](x[0])\n        for i in range(1, len(self.feature_strides)):\n            # non inplace\n            output = output + resize(\n                self.scale_heads[i](x[i]),\n                size=output.shape[2:],\n                mode='bilinear',\n                align_corners=self.align_corners)\n\n        output = self.cls_seg(output)\n        return output\n"
  },
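  {
    "path": "downstream_tasks/semantic_segmentation/examples/fpn_head_scale_depth_sketch.py",
    "content": "# Hypothetical illustration (file added for clarity; not part of the original\n# repository). It checks the scale-head depth rule used by FPNHead: a branch whose\n# input has stride s gets max(1, log2(s / feature_strides[0])) conv (+ upsample-by-2)\n# stages, so every branch ends at the resolution of the highest-resolution input.\nimport numpy as np\n\nfeature_strides = (4, 8, 16, 32)\nhead_lengths = [\n    max(1, int(np.log2(s) - np.log2(feature_strides[0])))\n    for s in feature_strides\n]\nprint(head_lengths)  # [1, 1, 2, 3]\n"
  },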
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/gc_head.py",
    "content": "import torch\nfrom mmcv.cnn import ContextBlock\n\nfrom ..builder import HEADS\nfrom .fcn_head import FCNHead\n\n\n@HEADS.register_module()\nclass GCHead(FCNHead):\n    \"\"\"GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.\n\n    This head is the implementation of `GCNet\n    <https://arxiv.org/abs/1904.11492>`_.\n\n    Args:\n        ratio (float): Multiplier of channels ratio. Default: 1/4.\n        pooling_type (str): The pooling type of context aggregation.\n            Options are 'att', 'avg'. Default: 'att'.\n        fusion_types (tuple[str]): The fusion type for feature fusion.\n            Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)\n    \"\"\"\n\n    def __init__(self,\n                 ratio=1 / 4.,\n                 pooling_type='att',\n                 fusion_types=('channel_add', ),\n                 **kwargs):\n        super(GCHead, self).__init__(num_convs=2, **kwargs)\n        self.ratio = ratio\n        self.pooling_type = pooling_type\n        self.fusion_types = fusion_types\n        self.gc_block = ContextBlock(\n            in_channels=self.channels,\n            ratio=self.ratio,\n            pooling_type=self.pooling_type,\n            fusion_types=self.fusion_types)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        output = self.convs[0](x)\n        output = self.gc_block(output)\n        output = self.convs[1](output)\n        if self.concat_input:\n            output = self.conv_cat(torch.cat([x, output], dim=1))\n        output = self.cls_seg(output)\n        return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/lraspp_head.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv import is_tuple_of\nfrom mmcv.cnn import ConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\n\n@HEADS.register_module()\nclass LRASPPHead(BaseDecodeHead):\n    \"\"\"Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.\n\n    This head is the improved implementation of `Searching for MobileNetV3\n    <https://ieeexplore.ieee.org/document/9008835>`_.\n\n    Args:\n        branch_channels (tuple[int]): The number of output channels in every\n            each branch. Default: (32, 64).\n    \"\"\"\n\n    def __init__(self, branch_channels=(32, 64), **kwargs):\n        super(LRASPPHead, self).__init__(**kwargs)\n        if self.input_transform != 'multiple_select':\n            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '\n                             f'must be \\'multiple_select\\'. But received '\n                             f'\\'{self.input_transform}\\'')\n        assert is_tuple_of(branch_channels, int)\n        assert len(branch_channels) == len(self.in_channels) - 1\n        self.branch_channels = branch_channels\n\n        self.convs = nn.Sequential()\n        self.conv_ups = nn.Sequential()\n        for i in range(len(branch_channels)):\n            self.convs.add_module(\n                f'conv{i}',\n                nn.Conv2d(\n                    self.in_channels[i], branch_channels[i], 1, bias=False))\n            self.conv_ups.add_module(\n                f'conv_up{i}',\n                ConvModule(\n                    self.channels + branch_channels[i],\n                    self.channels,\n                    1,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg,\n                    bias=False))\n\n        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)\n\n        self.aspp_conv = ConvModule(\n            self.in_channels[-1],\n            self.channels,\n            1,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg,\n            bias=False)\n        self.image_pool = nn.Sequential(\n            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),\n            ConvModule(\n                self.in_channels[2],\n                self.channels,\n                1,\n                act_cfg=dict(type='Sigmoid'),\n                bias=False))\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        inputs = self._transform_inputs(inputs)\n\n        x = inputs[-1]\n\n        x = self.aspp_conv(x) * resize(\n            self.image_pool(x),\n            size=x.size()[2:],\n            mode='bilinear',\n            align_corners=self.align_corners)\n        x = self.conv_up_input(x)\n\n        for i in range(len(self.branch_channels) - 1, -1, -1):\n            x = resize(\n                x,\n                size=inputs[i].size()[2:],\n                mode='bilinear',\n                align_corners=self.align_corners)\n            x = torch.cat([x, self.convs[i](inputs[i])], 1)\n            x = self.conv_ups[i](x)\n\n        return self.cls_seg(x)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/nl_head.py",
    "content": "import torch\nfrom mmcv.cnn import NonLocal2d\n\nfrom ..builder import HEADS\nfrom .fcn_head import FCNHead\n\n\n@HEADS.register_module()\nclass NLHead(FCNHead):\n    \"\"\"Non-local Neural Networks.\n\n    This head is the implementation of `NLNet\n    <https://arxiv.org/abs/1711.07971>`_.\n\n    Args:\n        reduction (int): Reduction factor of projection transform. Default: 2.\n        use_scale (bool): Whether to scale pairwise_weight by\n            sqrt(1/inter_channels). Default: True.\n        mode (str): The nonlocal mode. Options are 'embedded_gaussian',\n            'dot_product'. Default: 'embedded_gaussian'.\n    \"\"\"\n\n    def __init__(self,\n                 reduction=2,\n                 use_scale=True,\n                 mode='embedded_gaussian',\n                 **kwargs):\n        super(NLHead, self).__init__(num_convs=2, **kwargs)\n        self.reduction = reduction\n        self.use_scale = use_scale\n        self.mode = mode\n        self.nl_block = NonLocal2d(\n            in_channels=self.channels,\n            reduction=self.reduction,\n            use_scale=self.use_scale,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            mode=self.mode)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        output = self.convs[0](x)\n        output = self.nl_block(output)\n        output = self.convs[1](output)\n        if self.concat_input:\n            output = self.conv_cat(torch.cat([x, output], dim=1))\n        output = self.cls_seg(output)\n        return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ocr_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom ..utils import SelfAttentionBlock as _SelfAttentionBlock\nfrom .cascade_decode_head import BaseCascadeDecodeHead\n\n\nclass SpatialGatherModule(nn.Module):\n    \"\"\"Aggregate the context features according to the initial predicted\n    probability distribution.\n\n    Employ the soft-weighted method to aggregate the context.\n    \"\"\"\n\n    def __init__(self, scale):\n        super(SpatialGatherModule, self).__init__()\n        self.scale = scale\n\n    def forward(self, feats, probs):\n        \"\"\"Forward function.\"\"\"\n        batch_size, num_classes, height, width = probs.size()\n        channels = feats.size(1)\n        probs = probs.view(batch_size, num_classes, -1)\n        feats = feats.view(batch_size, channels, -1)\n        # [batch_size, height*width, num_classes]\n        feats = feats.permute(0, 2, 1)\n        # [batch_size, channels, height*width]\n        probs = F.softmax(self.scale * probs, dim=2)\n        # [batch_size, channels, num_classes]\n        ocr_context = torch.matmul(probs, feats)\n        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)\n        return ocr_context\n\n\nclass ObjectAttentionBlock(_SelfAttentionBlock):\n    \"\"\"Make a OCR used SelfAttentionBlock.\"\"\"\n\n    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,\n                 act_cfg):\n        if scale > 1:\n            query_downsample = nn.MaxPool2d(kernel_size=scale)\n        else:\n            query_downsample = None\n        super(ObjectAttentionBlock, self).__init__(\n            key_in_channels=in_channels,\n            query_in_channels=in_channels,\n            channels=channels,\n            out_channels=in_channels,\n            share_key_query=False,\n            query_downsample=query_downsample,\n            key_downsample=None,\n            key_query_num_convs=2,\n            key_query_norm=True,\n            value_out_num_convs=1,\n            value_out_norm=True,\n            matmul_norm=True,\n            with_out=True,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n        self.bottleneck = ConvModule(\n            in_channels * 2,\n            in_channels,\n            1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def forward(self, query_feats, key_feats):\n        \"\"\"Forward function.\"\"\"\n        context = super(ObjectAttentionBlock,\n                        self).forward(query_feats, key_feats)\n        output = self.bottleneck(torch.cat([context, query_feats], dim=1))\n        if self.query_downsample is not None:\n            output = resize(query_feats)\n\n        return output\n\n\n@HEADS.register_module()\nclass OCRHead(BaseCascadeDecodeHead):\n    \"\"\"Object-Contextual Representations for Semantic Segmentation.\n\n    This head is the implementation of `OCRNet\n    <https://arxiv.org/abs/1909.11065>`_.\n\n    Args:\n        ocr_channels (int): The intermediate channels of OCR block.\n        scale (int): The scale of probability map in SpatialGatherModule in\n            Default: 1.\n    \"\"\"\n\n    def __init__(self, ocr_channels, scale=1, **kwargs):\n        super(OCRHead, self).__init__(**kwargs)\n        self.ocr_channels = ocr_channels\n        self.scale = scale\n        self.object_context_block = 
ObjectAttentionBlock(\n            self.channels,\n            self.ocr_channels,\n            self.scale,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.spatial_gather_module = SpatialGatherModule(self.scale)\n\n        self.bottleneck = ConvModule(\n            self.in_channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def forward(self, inputs, prev_output):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        feats = self.bottleneck(x)\n        context = self.spatial_gather_module(feats, prev_output)\n        object_context = self.object_context_block(feats, context)\n        output = self.cls_seg(object_context)\n\n        return output\n"
  },
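  {
    "path": "downstream_tasks/semantic_segmentation/examples/ocr_spatial_gather_sketch.py",
    "content": "# Hypothetical illustration (file added for clarity; not part of the original\n# repository). It replays SpatialGatherModule from ocr_head.py on random tensors:\n# the coarse class probability maps act as soft spatial weights that pool one\n# context vector per class from the pixel features.\nimport torch\nimport torch.nn.functional as F\n\nbatch, num_classes, channels, h, w = 2, 19, 32, 8, 8\nscale = 1\nfeats = torch.randn(batch, channels, h, w)\nprobs = torch.randn(batch, num_classes, h, w)  # stands in for the previous head's logits\n\nprobs = probs.view(batch, num_classes, -1)\nfeats = feats.view(batch, channels, -1).permute(0, 2, 1)  # [N, HxW, C]\nprobs = F.softmax(scale * probs, dim=2)                   # soft weights over positions\nocr_context = torch.matmul(probs, feats)                  # [N, num_classes, C]\nocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)\nprint(ocr_context.shape)  # torch.Size([2, 32, 19, 1])\n"
  },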
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/point_head.py",
    "content": "# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py  # noqa\n\nimport torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule, normal_init\nfrom mmcv.ops import point_sample\n\nfrom mmseg.models.builder import HEADS\nfrom mmseg.ops import resize\nfrom ..losses import accuracy\nfrom .cascade_decode_head import BaseCascadeDecodeHead\n\n\ndef calculate_uncertainty(seg_logits):\n    \"\"\"Estimate uncertainty based on seg logits.\n\n    For each location of the prediction ``seg_logits`` we estimate\n    uncertainty as the difference between top first and top second\n    predicted logits.\n\n    Args:\n        seg_logits (Tensor): Semantic segmentation logits,\n            shape (batch_size, num_classes, height, width).\n\n    Returns:\n        scores (Tensor): T uncertainty scores with the most uncertain\n            locations having the highest uncertainty score, shape (\n            batch_size, 1, height, width)\n    \"\"\"\n    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]\n    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)\n\n\n@HEADS.register_module()\nclass PointHead(BaseCascadeDecodeHead):\n    \"\"\"A mask point head use in PointRend.\n\n    ``PointHead`` use shared multi-layer perceptron (equivalent to\n    nn.Conv1d) to predict the logit of input points. The fine-grained feature\n    and coarse feature will be concatenate together for predication.\n\n    Args:\n        num_fcs (int): Number of fc layers in the head. Default: 3.\n        in_channels (int): Number of input channels. Default: 256.\n        fc_channels (int): Number of fc channels. Default: 256.\n        num_classes (int): Number of classes for logits. Default: 80.\n        class_agnostic (bool): Whether use class agnostic classification.\n            If so, the output channels of logits will be 1. Default: False.\n        coarse_pred_each_layer (bool): Whether concatenate coarse feature with\n            the output of each fc layer. Default: True.\n        conv_cfg (dict|None): Dictionary to construct and config conv layer.\n            Default: dict(type='Conv1d'))\n        norm_cfg (dict|None): Dictionary to construct and config norm layer.\n            Default: None.\n        loss_point (dict): Dictionary to construct and config loss layer of\n            point head. 
Default: dict(type='CrossEntropyLoss', use_mask=True,\n            loss_weight=1.0).\n    \"\"\"\n\n    def __init__(self,\n                 num_fcs=3,\n                 coarse_pred_each_layer=True,\n                 conv_cfg=dict(type='Conv1d'),\n                 norm_cfg=None,\n                 act_cfg=dict(type='ReLU', inplace=False),\n                 **kwargs):\n        super(PointHead, self).__init__(\n            input_transform='multiple_select',\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg,\n            **kwargs)\n\n        self.num_fcs = num_fcs\n        self.coarse_pred_each_layer = coarse_pred_each_layer\n\n        fc_in_channels = sum(self.in_channels) + self.num_classes\n        fc_channels = self.channels\n        self.fcs = nn.ModuleList()\n        for k in range(num_fcs):\n            fc = ConvModule(\n                fc_in_channels,\n                fc_channels,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg)\n            self.fcs.append(fc)\n            fc_in_channels = fc_channels\n            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \\\n                else 0\n        self.fc_seg = nn.Conv1d(\n            fc_in_channels,\n            self.num_classes,\n            kernel_size=1,\n            stride=1,\n            padding=0)\n        if self.dropout_ratio > 0:\n            self.dropout = nn.Dropout(self.dropout_ratio)\n        delattr(self, 'conv_seg')\n\n    def init_weights(self):\n        \"\"\"Initialize weights of classification layer.\"\"\"\n        normal_init(self.fc_seg, std=0.001)\n\n    def cls_seg(self, feat):\n        \"\"\"Classify each pixel with fc.\"\"\"\n        if self.dropout is not None:\n            feat = self.dropout(feat)\n        output = self.fc_seg(feat)\n        return output\n\n    def forward(self, fine_grained_point_feats, coarse_point_feats):\n        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)\n        for fc in self.fcs:\n            x = fc(x)\n            if self.coarse_pred_each_layer:\n                x = torch.cat((x, coarse_point_feats), dim=1)\n        return self.cls_seg(x)\n\n    def _get_fine_grained_point_feats(self, x, points):\n        \"\"\"Sample from fine grained features.\n\n        Args:\n            x (list[Tensor]): Feature pyramid from neck or backbone.\n            points (Tensor): Point coordinates, shape (batch_size,\n                num_points, 2).\n\n        Returns:\n            fine_grained_feats (Tensor): Sampled fine grained feature,\n                shape (batch_size, sum(channels of x), num_points).\n        \"\"\"\n\n        fine_grained_feats_list = [\n            point_sample(_, points, align_corners=self.align_corners)\n            for _ in x\n        ]\n        if len(fine_grained_feats_list) > 1:\n            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)\n        else:\n            fine_grained_feats = fine_grained_feats_list[0]\n\n        return fine_grained_feats\n\n    def _get_coarse_point_feats(self, prev_output, points):\n        \"\"\"Sample from the coarse prediction.\n\n        Args:\n            prev_output (Tensor): Prediction of previous decode head.\n            points (Tensor): Point coordinates, shape (batch_size,\n                num_points, 2).\n\n        Returns:\n            coarse_feats (Tensor): Sampled coarse 
feature, shape (batch_size,\n                num_classes, num_points).\n        \"\"\"\n\n        coarse_feats = point_sample(\n            prev_output, points, align_corners=self.align_corners)\n\n        return coarse_feats\n\n    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,\n                      train_cfg):\n        \"\"\"Forward function for training.\n        Args:\n            inputs (list[Tensor]): List of multi-level img features.\n            prev_output (Tensor): The output of previous decode head.\n            img_metas (list[dict]): List of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmseg/datasets/pipelines/formatting.py:Collect`.\n            gt_semantic_seg (Tensor): Semantic segmentation masks\n                used if the architecture supports semantic segmentation task.\n            train_cfg (dict): The training config.\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n        x = self._transform_inputs(inputs)\n        with torch.no_grad():\n            points = self.get_points_train(\n                prev_output, calculate_uncertainty, cfg=train_cfg)\n        fine_grained_point_feats = self._get_fine_grained_point_feats(\n            x, points)\n        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)\n        point_logits = self.forward(fine_grained_point_feats,\n                                    coarse_point_feats)\n        point_label = point_sample(\n            gt_semantic_seg.float(),\n            points,\n            mode='nearest',\n            align_corners=self.align_corners)\n        point_label = point_label.squeeze(1).long()\n\n        losses = self.losses(point_logits, point_label)\n\n        return losses\n\n    def forward_test(self, inputs, prev_output, img_metas, test_cfg):\n        \"\"\"Forward function for testing.\n\n        Args:\n            inputs (list[Tensor]): List of multi-level img features.\n            prev_output (Tensor): The output of previous decode head.\n            img_metas (list[dict]): List of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmseg/datasets/pipelines/formatting.py:Collect`.\n            test_cfg (dict): The testing config.\n\n        Returns:\n            Tensor: Output segmentation map.\n        \"\"\"\n\n        x = self._transform_inputs(inputs)\n        refined_seg_logits = prev_output.clone()\n        for _ in range(test_cfg.subdivision_steps):\n            refined_seg_logits = resize(\n                refined_seg_logits,\n                scale_factor=test_cfg.scale_factor,\n                mode='bilinear',\n                align_corners=self.align_corners)\n            batch_size, channels, height, width = refined_seg_logits.shape\n            point_indices, points = self.get_points_test(\n                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)\n            fine_grained_point_feats = self._get_fine_grained_point_feats(\n                x, points)\n            coarse_point_feats = self._get_coarse_point_feats(\n                prev_output, points)\n           
 point_logits = self.forward(fine_grained_point_feats,\n                                        coarse_point_feats)\n\n            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)\n            refined_seg_logits = refined_seg_logits.reshape(\n                batch_size, channels, height * width)\n            refined_seg_logits = refined_seg_logits.scatter_(\n                2, point_indices, point_logits)\n            refined_seg_logits = refined_seg_logits.view(\n                batch_size, channels, height, width)\n\n        return refined_seg_logits\n\n    def losses(self, point_logits, point_label):\n        \"\"\"Compute segmentation loss.\"\"\"\n        loss = dict()\n        loss['loss_point'] = self.loss_decode(\n            point_logits, point_label, ignore_index=self.ignore_index)\n        loss['acc_point'] = accuracy(point_logits, point_label)\n        return loss\n\n    def get_points_train(self, seg_logits, uncertainty_func, cfg):\n        \"\"\"Sample points for training.\n\n        Sample points in [0, 1] x [0, 1] coordinate space based on their\n        uncertainty. The uncertainties are calculated for each point using\n        'uncertainty_func' function that takes point's logit prediction as\n        input.\n\n        Args:\n            seg_logits (Tensor): Semantic segmentation logits, shape (\n                batch_size, num_classes, height, width).\n            uncertainty_func (func): uncertainty calculation function.\n            cfg (dict): Training config of point head.\n\n        Returns:\n            point_coords (Tensor): A tensor of shape (batch_size, num_points,\n                2) that contains the coordinates of ``num_points`` sampled\n                points.\n        \"\"\"\n        num_points = cfg.num_points\n        oversample_ratio = cfg.oversample_ratio\n        importance_sample_ratio = cfg.importance_sample_ratio\n        assert oversample_ratio >= 1\n        assert 0 <= importance_sample_ratio <= 1\n        batch_size = seg_logits.shape[0]\n        num_sampled = int(num_points * oversample_ratio)\n        point_coords = torch.rand(\n            batch_size, num_sampled, 2, device=seg_logits.device)\n        point_logits = point_sample(seg_logits, point_coords)\n        # It is crucial to calculate uncertainty based on the sampled\n        # prediction value for the points. Calculating uncertainties of the\n        # coarse predictions first and sampling them for points leads to\n        # incorrect results.  To illustrate this: assume uncertainty func(\n        # logits)=-abs(logits), a sampled point between two coarse\n        # predictions with -1 and 1 logits has 0 logits, and therefore 0\n        # uncertainty value. 
However, if we calculate uncertainties for the\n        # coarse predictions first, both will have -1 uncertainty,\n        # and sampled point will get -1 uncertainty.\n        point_uncertainties = uncertainty_func(point_logits)\n        num_uncertain_points = int(importance_sample_ratio * num_points)\n        num_random_points = num_points - num_uncertain_points\n        idx = torch.topk(\n            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]\n        shift = num_sampled * torch.arange(\n            batch_size, dtype=torch.long, device=seg_logits.device)\n        idx += shift[:, None]\n        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(\n            batch_size, num_uncertain_points, 2)\n        if num_random_points > 0:\n            rand_point_coords = torch.rand(\n                batch_size, num_random_points, 2, device=seg_logits.device)\n            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)\n        return point_coords\n\n    def get_points_test(self, seg_logits, uncertainty_func, cfg):\n        \"\"\"Sample points for testing.\n\n        Find ``num_points`` most uncertain points from ``uncertainty_map``.\n\n        Args:\n            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,\n                height, width) for class-specific or class-agnostic prediction.\n            uncertainty_func (func): uncertainty calculation function.\n            cfg (dict): Testing config of point head.\n\n        Returns:\n            point_indices (Tensor): A tensor of shape (batch_size, num_points)\n                that contains indices from [0, height x width) of the most\n                uncertain points.\n            point_coords (Tensor): A tensor of shape (batch_size, num_points,\n                2) that contains [0, 1] x [0, 1] normalized coordinates of the\n                most uncertain points from the ``height x width`` grid .\n        \"\"\"\n\n        num_points = cfg.subdivision_num_points\n        uncertainty_map = uncertainty_func(seg_logits)\n        batch_size, _, height, width = uncertainty_map.shape\n        h_step = 1.0 / height\n        w_step = 1.0 / width\n\n        uncertainty_map = uncertainty_map.view(batch_size, height * width)\n        num_points = min(height * width, num_points)\n        point_indices = uncertainty_map.topk(num_points, dim=1)[1]\n        point_coords = torch.zeros(\n            batch_size,\n            num_points,\n            2,\n            dtype=torch.float,\n            device=seg_logits.device)\n        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %\n                                                width).float() * w_step\n        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //\n                                                width).float() * h_step\n        return point_indices, point_coords\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/psa_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\ntry:\n    from mmcv.ops import PSAMask\nexcept ModuleNotFoundError:\n    PSAMask = None\n\n\n@HEADS.register_module()\nclass PSAHead(BaseDecodeHead):\n    \"\"\"Point-wise Spatial Attention Network for Scene Parsing.\n\n    This head is the implementation of `PSANet\n    <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.\n\n    Args:\n        mask_size (tuple[int]): The PSA mask size. It usually equals input\n            size.\n        psa_type (str): The type of psa module. Options are 'collect',\n            'distribute', 'bi-direction'. Default: 'bi-direction'\n        compact (bool): Whether use compact map for 'collect' mode.\n            Default: True.\n        shrink_factor (int): The downsample factors of psa mask. Default: 2.\n        normalization_factor (float): The normalize factor of attention.\n        psa_softmax (bool): Whether use softmax for attention.\n    \"\"\"\n\n    def __init__(self,\n                 mask_size,\n                 psa_type='bi-direction',\n                 compact=False,\n                 shrink_factor=2,\n                 normalization_factor=1.0,\n                 psa_softmax=True,\n                 **kwargs):\n        if PSAMask is None:\n            raise RuntimeError('Please install mmcv-full for PSAMask ops')\n        super(PSAHead, self).__init__(**kwargs)\n        assert psa_type in ['collect', 'distribute', 'bi-direction']\n        self.psa_type = psa_type\n        self.compact = compact\n        self.shrink_factor = shrink_factor\n        self.mask_size = mask_size\n        mask_h, mask_w = mask_size\n        self.psa_softmax = psa_softmax\n        if normalization_factor is None:\n            normalization_factor = mask_h * mask_w\n        self.normalization_factor = normalization_factor\n\n        self.reduce = ConvModule(\n            self.in_channels,\n            self.channels,\n            kernel_size=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.attention = nn.Sequential(\n            ConvModule(\n                self.channels,\n                self.channels,\n                kernel_size=1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg),\n            nn.Conv2d(\n                self.channels, mask_h * mask_w, kernel_size=1, bias=False))\n        if psa_type == 'bi-direction':\n            self.reduce_p = ConvModule(\n                self.in_channels,\n                self.channels,\n                kernel_size=1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg)\n            self.attention_p = nn.Sequential(\n                ConvModule(\n                    self.channels,\n                    self.channels,\n                    kernel_size=1,\n                    conv_cfg=self.conv_cfg,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg),\n                nn.Conv2d(\n                    self.channels, mask_h * mask_w, kernel_size=1, bias=False))\n            self.psamask_collect = PSAMask('collect', mask_size)\n            self.psamask_distribute = PSAMask('distribute', mask_size)\n        else:\n            self.psamask = PSAMask(psa_type, mask_size)\n   
     self.proj = ConvModule(\n            self.channels * (2 if psa_type == 'bi-direction' else 1),\n            self.in_channels,\n            kernel_size=1,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        self.bottleneck = ConvModule(\n            self.in_channels * 2,\n            self.channels,\n            kernel_size=3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        identity = x\n        align_corners = self.align_corners\n        if self.psa_type in ['collect', 'distribute']:\n            out = self.reduce(x)\n            n, c, h, w = out.size()\n            if self.shrink_factor != 1:\n                if h % self.shrink_factor and w % self.shrink_factor:\n                    h = (h - 1) // self.shrink_factor + 1\n                    w = (w - 1) // self.shrink_factor + 1\n                    align_corners = True\n                else:\n                    h = h // self.shrink_factor\n                    w = w // self.shrink_factor\n                    align_corners = False\n                out = resize(\n                    out,\n                    size=(h, w),\n                    mode='bilinear',\n                    align_corners=align_corners)\n            y = self.attention(out)\n            if self.compact:\n                if self.psa_type == 'collect':\n                    y = y.view(n, h * w,\n                               h * w).transpose(1, 2).view(n, h * w, h, w)\n            else:\n                y = self.psamask(y)\n            if self.psa_softmax:\n                y = F.softmax(y, dim=1)\n            out = torch.bmm(\n                out.view(n, c, h * w), y.view(n, h * w, h * w)).view(\n                    n, c, h, w) * (1.0 / self.normalization_factor)\n        else:\n            x_col = self.reduce(x)\n            x_dis = self.reduce_p(x)\n            n, c, h, w = x_col.size()\n            if self.shrink_factor != 1:\n                if h % self.shrink_factor and w % self.shrink_factor:\n                    h = (h - 1) // self.shrink_factor + 1\n                    w = (w - 1) // self.shrink_factor + 1\n                    align_corners = True\n                else:\n                    h = h // self.shrink_factor\n                    w = w // self.shrink_factor\n                    align_corners = False\n                x_col = resize(\n                    x_col,\n                    size=(h, w),\n                    mode='bilinear',\n                    align_corners=align_corners)\n                x_dis = resize(\n                    x_dis,\n                    size=(h, w),\n                    mode='bilinear',\n                    align_corners=align_corners)\n            y_col = self.attention(x_col)\n            y_dis = self.attention_p(x_dis)\n            if self.compact:\n                y_dis = y_dis.view(n, h * w,\n                                   h * w).transpose(1, 2).view(n, h * w, h, w)\n            else:\n                y_col = self.psamask_collect(y_col)\n                y_dis = self.psamask_distribute(y_dis)\n            if self.psa_softmax:\n                y_col = F.softmax(y_col, dim=1)\n                y_dis = F.softmax(y_dis, dim=1)\n            x_col = torch.bmm(\n                x_col.view(n, c, h * w), y_col.view(n, h * w, h * 
w)).view(\n                    n, c, h, w) * (1.0 / self.normalization_factor)\n            x_dis = torch.bmm(\n                x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(\n                    n, c, h, w) * (1.0 / self.normalization_factor)\n            out = torch.cat([x_col, x_dis], 1)\n        out = self.proj(out)\n        out = resize(\n            out,\n            size=identity.shape[2:],\n            mode='bilinear',\n            align_corners=align_corners)\n        out = self.bottleneck(torch.cat((identity, out), dim=1))\n        out = self.cls_seg(out)\n        return out\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/psp_head.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\n\n\nclass PPM(nn.ModuleList):\n    \"\"\"Pooling Pyramid Module used in PSPNet.\n\n    Args:\n        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid\n            Module.\n        in_channels (int): Input channels.\n        channels (int): Channels after modules, before conv_seg.\n        conv_cfg (dict|None): Config of conv layers.\n        norm_cfg (dict|None): Config of norm layers.\n        act_cfg (dict): Config of activation layers.\n        align_corners (bool): align_corners argument of F.interpolate.\n    \"\"\"\n\n    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,\n                 act_cfg, align_corners):\n        super(PPM, self).__init__()\n        self.pool_scales = pool_scales\n        self.align_corners = align_corners\n        self.in_channels = in_channels\n        self.channels = channels\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        for pool_scale in pool_scales:\n            self.append(\n                nn.Sequential(\n                    nn.AdaptiveAvgPool2d(pool_scale),\n                    ConvModule(\n                        self.in_channels,\n                        self.channels,\n                        1,\n                        conv_cfg=self.conv_cfg,\n                        norm_cfg=self.norm_cfg,\n                        act_cfg=self.act_cfg)))\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        ppm_outs = []\n        for ppm in self:\n            ppm_out = ppm(x)\n            upsampled_ppm_out = resize(\n                ppm_out,\n                size=x.size()[2:],\n                mode='bilinear',\n                align_corners=self.align_corners)\n            ppm_outs.append(upsampled_ppm_out)\n        return ppm_outs\n\n\n@HEADS.register_module()\nclass PSPHead(BaseDecodeHead):\n    \"\"\"Pyramid Scene Parsing Network.\n\n    This head is the implementation of\n    `PSPNet <https://arxiv.org/abs/1612.01105>`_.\n\n    Args:\n        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid\n            Module. Default: (1, 2, 3, 6).\n    \"\"\"\n\n    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):\n        super(PSPHead, self).__init__(**kwargs)\n        assert isinstance(pool_scales, (list, tuple))\n        self.pool_scales = pool_scales\n        self.psp_modules = PPM(\n            self.pool_scales,\n            self.in_channels,\n            self.channels,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg,\n            align_corners=self.align_corners)\n        self.bottleneck = ConvModule(\n            self.in_channels + len(pool_scales) * self.channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        psp_outs = [x]\n        psp_outs.extend(self.psp_modules(x))\n        psp_outs = torch.cat(psp_outs, dim=1)\n        output = self.bottleneck(psp_outs)\n        output = self.cls_seg(output)\n        return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/sep_aspp_head.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule, DepthwiseSeparableConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom .aspp_head import ASPPHead, ASPPModule\n\n\nclass DepthwiseSeparableASPPModule(ASPPModule):\n    \"\"\"Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable\n    conv.\"\"\"\n\n    def __init__(self, **kwargs):\n        super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)\n        for i, dilation in enumerate(self.dilations):\n            if dilation > 1:\n                self[i] = DepthwiseSeparableConvModule(\n                    self.in_channels,\n                    self.channels,\n                    3,\n                    dilation=dilation,\n                    padding=dilation,\n                    norm_cfg=self.norm_cfg,\n                    act_cfg=self.act_cfg)\n\n\n@HEADS.register_module()\nclass DepthwiseSeparableASPPHead(ASPPHead):\n    \"\"\"Encoder-Decoder with Atrous Separable Convolution for Semantic Image\n    Segmentation.\n\n    This head is the implementation of `DeepLabV3+\n    <https://arxiv.org/abs/1802.02611>`_.\n\n    Args:\n        c1_in_channels (int): The input channels of c1 decoder. If is 0,\n            the no decoder will be used.\n        c1_channels (int): The intermediate channels of c1 decoder.\n    \"\"\"\n\n    def __init__(self, c1_in_channels, c1_channels, **kwargs):\n        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)\n        assert c1_in_channels >= 0\n        self.aspp_modules = DepthwiseSeparableASPPModule(\n            dilations=self.dilations,\n            in_channels=self.in_channels,\n            channels=self.channels,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        if c1_in_channels > 0:\n            self.c1_bottleneck = ConvModule(\n                c1_in_channels,\n                c1_channels,\n                1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg)\n        else:\n            self.c1_bottleneck = None\n        self.sep_bottleneck = nn.Sequential(\n            DepthwiseSeparableConvModule(\n                self.channels + c1_channels,\n                self.channels,\n                3,\n                padding=1,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg),\n            DepthwiseSeparableConvModule(\n                self.channels,\n                self.channels,\n                3,\n                padding=1,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg))\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        x = self._transform_inputs(inputs)\n        aspp_outs = [\n            resize(\n                self.image_pool(x),\n                size=x.size()[2:],\n                mode='bilinear',\n                align_corners=self.align_corners)\n        ]\n        aspp_outs.extend(self.aspp_modules(x))\n        aspp_outs = torch.cat(aspp_outs, dim=1)\n        output = self.bottleneck(aspp_outs)\n        if self.c1_bottleneck is not None:\n            c1_output = self.c1_bottleneck(inputs[0])\n            output = resize(\n                input=output,\n                size=c1_output.shape[2:],\n                mode='bilinear',\n                align_corners=self.align_corners)\n            output = torch.cat([output, c1_output], dim=1)\n        output = 
self.sep_bottleneck(output)\n        output = self.cls_seg(output)\n        return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/sep_fcn_head.py",
    "content": "from mmcv.cnn import DepthwiseSeparableConvModule\n\nfrom ..builder import HEADS\nfrom .fcn_head import FCNHead\n\n\n@HEADS.register_module()\nclass DepthwiseSeparableFCNHead(FCNHead):\n    \"\"\"Depthwise-Separable Fully Convolutional Network for Semantic\n    Segmentation.\n\n    This head is implemented according to Fast-SCNN paper.\n    Args:\n        in_channels(int): Number of output channels of FFM.\n        channels(int): Number of middle-stage channels in the decode head.\n        concat_input(bool): Whether to concatenate original decode input into\n            the result of several consecutive convolution layers.\n            Default: True.\n        num_classes(int): Used to determine the dimension of\n            final prediction tensor.\n        in_index(int): Correspond with 'out_indices' in FastSCNN backbone.\n        norm_cfg (dict | None): Config of norm layers.\n        align_corners (bool): align_corners argument of F.interpolate.\n            Default: False.\n        loss_decode(dict): Config of loss type and some\n            relevant additional options.\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)\n        self.convs[0] = DepthwiseSeparableConvModule(\n            self.in_channels,\n            self.channels,\n            kernel_size=self.kernel_size,\n            padding=self.kernel_size // 2,\n            norm_cfg=self.norm_cfg)\n        for i in range(1, self.num_convs):\n            self.convs[i] = DepthwiseSeparableConvModule(\n                self.channels,\n                self.channels,\n                kernel_size=self.kernel_size,\n                padding=self.kernel_size // 2,\n                norm_cfg=self.norm_cfg)\n\n        if self.concat_input:\n            self.conv_cat = DepthwiseSeparableConvModule(\n                self.in_channels + self.channels,\n                self.channels,\n                kernel_size=self.kernel_size,\n                padding=self.kernel_size // 2,\n                norm_cfg=self.norm_cfg)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/uper_head.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule\n\nfrom mmseg.ops import resize\nfrom ..builder import HEADS\nfrom .decode_head import BaseDecodeHead\nfrom .psp_head import PPM\n\n\n@HEADS.register_module()\nclass UPerHead(BaseDecodeHead):\n    \"\"\"Unified Perceptual Parsing for Scene Understanding.\n\n    This head is the implementation of `UPerNet\n    <https://arxiv.org/abs/1807.10221>`_.\n\n    Args:\n        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid\n            Module applied on the last feature. Default: (1, 2, 3, 6).\n    \"\"\"\n\n    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):\n        super(UPerHead, self).__init__(\n            input_transform='multiple_select', **kwargs)\n        # PSP Module\n        self.psp_modules = PPM(\n            pool_scales,\n            self.in_channels[-1],\n            self.channels,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg,\n            align_corners=self.align_corners)\n        self.bottleneck = ConvModule(\n            self.in_channels[-1] + len(pool_scales) * self.channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n        # FPN Module\n        self.lateral_convs = nn.ModuleList()\n        self.fpn_convs = nn.ModuleList()\n        for in_channels in self.in_channels[:-1]:  # skip the top layer\n            l_conv = ConvModule(\n                in_channels,\n                self.channels,\n                1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg,\n                inplace=False)\n            fpn_conv = ConvModule(\n                self.channels,\n                self.channels,\n                3,\n                padding=1,\n                conv_cfg=self.conv_cfg,\n                norm_cfg=self.norm_cfg,\n                act_cfg=self.act_cfg,\n                inplace=False)\n            self.lateral_convs.append(l_conv)\n            self.fpn_convs.append(fpn_conv)\n\n        self.fpn_bottleneck = ConvModule(\n            len(self.in_channels) * self.channels,\n            self.channels,\n            3,\n            padding=1,\n            conv_cfg=self.conv_cfg,\n            norm_cfg=self.norm_cfg,\n            act_cfg=self.act_cfg)\n\n    def psp_forward(self, inputs):\n        \"\"\"Forward function of PSP module.\"\"\"\n        x = inputs[-1]\n        psp_outs = [x]\n        psp_outs.extend(self.psp_modules(x))\n        psp_outs = torch.cat(psp_outs, dim=1)\n        output = self.bottleneck(psp_outs)\n\n        return output\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n\n        inputs = self._transform_inputs(inputs)\n\n        # build laterals\n        laterals = [\n            lateral_conv(inputs[i])\n            for i, lateral_conv in enumerate(self.lateral_convs)\n        ]\n\n        laterals.append(self.psp_forward(inputs))\n\n        # build top-down path\n        used_backbone_levels = len(laterals)\n        for i in range(used_backbone_levels - 1, 0, -1):\n            prev_shape = laterals[i - 1].shape[2:]\n            laterals[i - 1] += resize(\n                laterals[i],\n                size=prev_shape,\n                mode='bilinear',\n                align_corners=self.align_corners)\n\n        # build outputs\n        fpn_outs = [\n            
self.fpn_convs[i](laterals[i])\n            for i in range(used_backbone_levels - 1)\n        ]\n        # append psp feature\n        fpn_outs.append(laterals[-1])\n\n        for i in range(used_backbone_levels - 1, 0, -1):\n            fpn_outs[i] = resize(\n                fpn_outs[i],\n                size=fpn_outs[0].shape[2:],\n                mode='bilinear',\n                align_corners=self.align_corners)\n        fpn_outs = torch.cat(fpn_outs, dim=1)\n        output = self.fpn_bottleneck(fpn_outs)\n        output = self.cls_seg(output)\n        return output\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/losses/__init__.py",
    "content": "from .accuracy import Accuracy, accuracy\nfrom .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,\n                                 cross_entropy, mask_cross_entropy)\nfrom .lovasz_loss import LovaszLoss\nfrom .utils import reduce_loss, weight_reduce_loss, weighted_loss\n\n__all__ = [\n    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',\n    'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',\n    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss'\n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/losses/accuracy.py",
    "content": "import torch.nn as nn\n\n\ndef accuracy(pred, target, topk=1, thresh=None):\n    \"\"\"Calculate accuracy according to the prediction and target.\n\n    Args:\n        pred (torch.Tensor): The model prediction, shape (N, num_class, ...)\n        target (torch.Tensor): The target of each prediction, shape (N, , ...)\n        topk (int | tuple[int], optional): If the predictions in ``topk``\n            matches the target, the predictions will be regarded as\n            correct ones. Defaults to 1.\n        thresh (float, optional): If not None, predictions with scores under\n            this threshold are considered incorrect. Default to None.\n\n    Returns:\n        float | tuple[float]: If the input ``topk`` is a single integer,\n            the function will return a single float as accuracy. If\n            ``topk`` is a tuple containing multiple integers, the\n            function will return a tuple containing accuracies of\n            each ``topk`` number.\n    \"\"\"\n    assert isinstance(topk, (int, tuple))\n    if isinstance(topk, int):\n        topk = (topk, )\n        return_single = True\n    else:\n        return_single = False\n\n    maxk = max(topk)\n    if pred.size(0) == 0:\n        accu = [pred.new_tensor(0.) for i in range(len(topk))]\n        return accu[0] if return_single else accu\n    assert pred.ndim == target.ndim + 1\n    assert pred.size(0) == target.size(0)\n    assert maxk <= pred.size(1), \\\n        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'\n    pred_value, pred_label = pred.topk(maxk, dim=1)\n    # transpose to shape (maxk, N, ...)\n    pred_label = pred_label.transpose(0, 1)\n    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))\n    if thresh is not None:\n        # Only prediction values larger than thresh are counted as correct\n        correct = correct & (pred_value > thresh).t()\n    res = []\n    for k in topk:\n        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)\n        res.append(correct_k.mul_(100.0 / target.numel()))\n    return res[0] if return_single else res\n\n\nclass Accuracy(nn.Module):\n    \"\"\"Accuracy calculation module.\"\"\"\n\n    def __init__(self, topk=(1, ), thresh=None):\n        \"\"\"Module to calculate the accuracy.\n\n        Args:\n            topk (tuple, optional): The criterion used to calculate the\n                accuracy. Defaults to (1,).\n            thresh (float, optional): If not None, predictions with scores\n                under this threshold are considered incorrect. Default to None.\n        \"\"\"\n        super().__init__()\n        self.topk = topk\n        self.thresh = thresh\n\n    def forward(self, pred, target):\n        \"\"\"Forward function to calculate accuracy.\n\n        Args:\n            pred (torch.Tensor): Prediction of models.\n            target (torch.Tensor): Target for each prediction.\n\n        Returns:\n            tuple[float]: The accuracies under different topk criterions.\n        \"\"\"\n        return accuracy(pred, target, self.topk, self.thresh)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/losses/cross_entropy_loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ..builder import LOSSES\nfrom .utils import weight_reduce_loss\n\n\ndef cross_entropy(pred,\n                  label,\n                  weight=None,\n                  class_weight=None,\n                  reduction='mean',\n                  avg_factor=None,\n                  ignore_index=-100):\n    \"\"\"The wrapper function for :func:`F.cross_entropy`\"\"\"\n    # class_weight is a manual rescaling weight given to each class.\n    # If given, has to be a Tensor of size C element-wise losses\n    loss = F.cross_entropy(\n        pred,\n        label,\n        weight=class_weight,\n        reduction='none',\n        ignore_index=ignore_index)\n\n    # apply weights and do the reduction\n    if weight is not None:\n        weight = weight.float()\n    loss = weight_reduce_loss(\n        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)\n\n    return loss\n\n\ndef _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):\n    \"\"\"Expand onehot labels to match the size of prediction.\"\"\"\n    bin_labels = labels.new_zeros(target_shape)\n    valid_mask = (labels >= 0) & (labels != ignore_index)\n    inds = torch.nonzero(valid_mask, as_tuple=True)\n\n    if inds[0].numel() > 0:\n        if labels.dim() == 3:\n            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1\n        else:\n            bin_labels[inds[0], labels[valid_mask]] = 1\n\n    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()\n    if label_weights is None:\n        bin_label_weights = valid_mask\n    else:\n        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)\n        bin_label_weights *= valid_mask\n\n    return bin_labels, bin_label_weights\n\n\ndef binary_cross_entropy(pred,\n                         label,\n                         weight=None,\n                         reduction='mean',\n                         avg_factor=None,\n                         class_weight=None,\n                         ignore_index=255):\n    \"\"\"Calculate the binary CrossEntropy loss.\n\n    Args:\n        pred (torch.Tensor): The prediction with shape (N, 1).\n        label (torch.Tensor): The learning label of the prediction.\n        weight (torch.Tensor, optional): Sample-wise loss weight.\n        reduction (str, optional): The method used to reduce the loss.\n            Options are \"none\", \"mean\" and \"sum\".\n        avg_factor (int, optional): Average factor that is used to average\n            the loss. Defaults to None.\n        class_weight (list[float], optional): The weight for each class.\n        ignore_index (int | None): The label index to be ignored. 
Default: 255\n\n    Returns:\n        torch.Tensor: The calculated loss\n    \"\"\"\n    if pred.dim() != label.dim():\n        assert (pred.dim() == 2 and label.dim() == 1) or (\n                pred.dim() == 4 and label.dim() == 3), \\\n            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \\\n            'H, W], label shape [N, H, W] are supported'\n        label, weight = _expand_onehot_labels(label, weight, pred.shape,\n                                              ignore_index)\n\n    # weighted element-wise losses\n    if weight is not None:\n        weight = weight.float()\n    loss = F.binary_cross_entropy_with_logits(\n        pred, label.float(), pos_weight=class_weight, reduction='none')\n    # do the reduction for the weighted loss\n    loss = weight_reduce_loss(\n        loss, weight, reduction=reduction, avg_factor=avg_factor)\n\n    return loss\n\n\ndef mask_cross_entropy(pred,\n                       target,\n                       label,\n                       reduction='mean',\n                       avg_factor=None,\n                       class_weight=None,\n                       ignore_index=None):\n    \"\"\"Calculate the CrossEntropy loss for masks.\n\n    Args:\n        pred (torch.Tensor): The prediction with shape (N, C), C is the number\n            of classes.\n        target (torch.Tensor): The learning label of the prediction.\n        label (torch.Tensor): ``label`` indicates the class label of the mask'\n            corresponding object. This will be used to select the mask in the\n            of the class which the object belongs to when the mask prediction\n            if not class-agnostic.\n        reduction (str, optional): The method used to reduce the loss.\n            Options are \"none\", \"mean\" and \"sum\".\n        avg_factor (int, optional): Average factor that is used to average\n            the loss. Defaults to None.\n        class_weight (list[float], optional): The weight for each class.\n        ignore_index (None): Placeholder, to be consistent with other loss.\n            Default: None.\n\n    Returns:\n        torch.Tensor: The calculated loss\n    \"\"\"\n    assert ignore_index is None, 'BCE loss does not support ignore_index'\n    # TODO: handle these two reserved arguments\n    assert reduction == 'mean' and avg_factor is None\n    num_rois = pred.size()[0]\n    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)\n    pred_slice = pred[inds, label].squeeze(1)\n    return F.binary_cross_entropy_with_logits(\n        pred_slice, target, weight=class_weight, reduction='mean')[None]\n\n\n@LOSSES.register_module()\nclass CrossEntropyLoss(nn.Module):\n    \"\"\"CrossEntropyLoss.\n\n    Args:\n        use_sigmoid (bool, optional): Whether the prediction uses sigmoid\n            of softmax. Defaults to False.\n        use_mask (bool, optional): Whether to use mask cross entropy loss.\n            Defaults to False.\n        reduction (str, optional): . Defaults to 'mean'.\n            Options are \"none\", \"mean\" and \"sum\".\n        class_weight (list[float], optional): Weight of each class.\n            Defaults to None.\n        loss_weight (float, optional): Weight of the loss. 
Defaults to 1.0.\n    \"\"\"\n\n    def __init__(self,\n                 use_sigmoid=False,\n                 use_mask=False,\n                 reduction='mean',\n                 class_weight=None,\n                 loss_weight=1.0):\n        super(CrossEntropyLoss, self).__init__()\n        assert (use_sigmoid is False) or (use_mask is False)\n        self.use_sigmoid = use_sigmoid\n        self.use_mask = use_mask\n        self.reduction = reduction\n        self.loss_weight = loss_weight\n        self.class_weight = class_weight\n\n        if self.use_sigmoid:\n            self.cls_criterion = binary_cross_entropy\n        elif self.use_mask:\n            self.cls_criterion = mask_cross_entropy\n        else:\n            self.cls_criterion = cross_entropy\n\n    def forward(self,\n                cls_score,\n                label,\n                weight=None,\n                avg_factor=None,\n                reduction_override=None,\n                **kwargs):\n        \"\"\"Forward function.\"\"\"\n        assert reduction_override in (None, 'none', 'mean', 'sum')\n        reduction = (\n            reduction_override if reduction_override else self.reduction)\n        if self.class_weight is not None:\n            class_weight = cls_score.new_tensor(self.class_weight)\n        else:\n            class_weight = None\n        loss_cls = self.loss_weight * self.cls_criterion(\n            cls_score,\n            label,\n            weight,\n            class_weight=class_weight,\n            reduction=reduction,\n            avg_factor=avg_factor,\n            **kwargs)\n        return loss_cls\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/losses/lovasz_loss.py",
    "content": "\"\"\"Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor\nch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim\nBerman 2018 ESAT-PSI KU Leuven (MIT License)\"\"\"\n\nimport mmcv\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ..builder import LOSSES\nfrom .utils import weight_reduce_loss\n\n\ndef lovasz_grad(gt_sorted):\n    \"\"\"Computes gradient of the Lovasz extension w.r.t sorted errors.\n\n    See Alg. 1 in paper.\n    \"\"\"\n    p = len(gt_sorted)\n    gts = gt_sorted.sum()\n    intersection = gts - gt_sorted.float().cumsum(0)\n    union = gts + (1 - gt_sorted).float().cumsum(0)\n    jaccard = 1. - intersection / union\n    if p > 1:  # cover 1-pixel case\n        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]\n    return jaccard\n\n\ndef flatten_binary_logits(logits, labels, ignore_index=None):\n    \"\"\"Flattens predictions in the batch (binary case) Remove labels equal to\n    'ignore_index'.\"\"\"\n    logits = logits.view(-1)\n    labels = labels.view(-1)\n    if ignore_index is None:\n        return logits, labels\n    valid = (labels != ignore_index)\n    vlogits = logits[valid]\n    vlabels = labels[valid]\n    return vlogits, vlabels\n\n\ndef flatten_probs(probs, labels, ignore_index=None):\n    \"\"\"Flattens predictions in the batch.\"\"\"\n    if probs.dim() == 3:\n        # assumes output of a sigmoid layer\n        B, H, W = probs.size()\n        probs = probs.view(B, 1, H, W)\n    B, C, H, W = probs.size()\n    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B*H*W, C=P,C\n    labels = labels.view(-1)\n    if ignore_index is None:\n        return probs, labels\n    valid = (labels != ignore_index)\n    vprobs = probs[valid.nonzero().squeeze()]\n    vlabels = labels[valid]\n    return vprobs, vlabels\n\n\ndef lovasz_hinge_flat(logits, labels):\n    \"\"\"Binary Lovasz hinge loss.\n\n    Args:\n        logits (torch.Tensor): [P], logits at each prediction\n            (between -infty and +infty).\n        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).\n\n    Returns:\n        torch.Tensor: The calculated loss.\n    \"\"\"\n    if len(labels) == 0:\n        # only void pixels, the gradients should be 0\n        return logits.sum() * 0.\n    signs = 2. * labels.float() - 1.\n    errors = (1. - logits * signs)\n    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)\n    perm = perm.data\n    gt_sorted = labels[perm]\n    grad = lovasz_grad(gt_sorted)\n    loss = torch.dot(F.relu(errors_sorted), grad)\n    return loss\n\n\ndef lovasz_hinge(logits,\n                 labels,\n                 classes='present',\n                 per_image=False,\n                 class_weight=None,\n                 reduction='mean',\n                 avg_factor=None,\n                 ignore_index=255):\n    \"\"\"Binary Lovasz hinge loss.\n\n    Args:\n        logits (torch.Tensor): [B, H, W], logits at each pixel\n            (between -infty and +infty).\n        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).\n        classes (str | list[int], optional): Placeholder, to be consistent with\n            other loss. Default: None.\n        per_image (bool, optional): If per_image is True, compute the loss per\n            image instead of per batch. Default: False.\n        class_weight (list[float], optional): Placeholder, to be consistent\n            with other loss. 
Default: None.\n        reduction (str, optional): The method used to reduce the loss. Options\n            are \"none\", \"mean\" and \"sum\". This parameter only works when\n            per_image is True. Default: 'mean'.\n        avg_factor (int, optional): Average factor that is used to average\n            the loss. This parameter only works when per_image is True.\n            Default: None.\n        ignore_index (int | None): The label index to be ignored. Default: 255.\n\n    Returns:\n        torch.Tensor: The calculated loss.\n    \"\"\"\n    if per_image:\n        loss = [\n            lovasz_hinge_flat(*flatten_binary_logits(\n                logit.unsqueeze(0), label.unsqueeze(0), ignore_index))\n            for logit, label in zip(logits, labels)\n        ]\n        loss = weight_reduce_loss(\n            torch.stack(loss), None, reduction, avg_factor)\n    else:\n        loss = lovasz_hinge_flat(\n            *flatten_binary_logits(logits, labels, ignore_index))\n    return loss\n\n\ndef lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):\n    \"\"\"Multi-class Lovasz-Softmax loss.\n\n    Args:\n        probs (torch.Tensor): [P, C], class probabilities at each prediction\n            (between 0 and 1).\n        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).\n        classes (str | list[int], optional): Classes choosed to calculate loss.\n            'all' for all classes, 'present' for classes present in labels, or\n            a list of classes to average. Default: 'present'.\n        class_weight (list[float], optional): The weight for each class.\n            Default: None.\n\n    Returns:\n        torch.Tensor: The calculated loss.\n    \"\"\"\n    if probs.numel() == 0:\n        # only void pixels, the gradients should be 0\n        return probs * 0.\n    C = probs.size(1)\n    losses = []\n    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes\n    for c in class_to_sum:\n        fg = (labels == c).float()  # foreground for class c\n        if (classes == 'present' and fg.sum() == 0):\n            continue\n        if C == 1:\n            if len(classes) > 1:\n                raise ValueError('Sigmoid output possible only with 1 class')\n            class_pred = probs[:, 0]\n        else:\n            class_pred = probs[:, c]\n        errors = (fg - class_pred).abs()\n        errors_sorted, perm = torch.sort(errors, 0, descending=True)\n        perm = perm.data\n        fg_sorted = fg[perm]\n        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))\n        if class_weight is not None:\n            loss *= class_weight[c]\n        losses.append(loss)\n    return torch.stack(losses).mean()\n\n\ndef lovasz_softmax(probs,\n                   labels,\n                   classes='present',\n                   per_image=False,\n                   class_weight=None,\n                   reduction='mean',\n                   avg_factor=None,\n                   ignore_index=255):\n    \"\"\"Multi-class Lovasz-Softmax loss.\n\n    Args:\n        probs (torch.Tensor): [B, C, H, W], class probabilities at each\n            prediction (between 0 and 1).\n        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and\n            C - 1).\n        classes (str | list[int], optional): Classes choosed to calculate loss.\n            'all' for all classes, 'present' for classes present in labels, or\n            a list of classes to average. 
Default: 'present'.\n        per_image (bool, optional): If per_image is True, compute the loss per\n            image instead of per batch. Default: False.\n        class_weight (list[float], optional): The weight for each class.\n            Default: None.\n        reduction (str, optional): The method used to reduce the loss. Options\n            are \"none\", \"mean\" and \"sum\". This parameter only works when\n            per_image is True. Default: 'mean'.\n        avg_factor (int, optional): Average factor that is used to average\n            the loss. This parameter only works when per_image is True.\n            Default: None.\n        ignore_index (int | None): The label index to be ignored. Default: 255.\n\n    Returns:\n        torch.Tensor: The calculated loss.\n    \"\"\"\n\n    if per_image:\n        loss = [\n            lovasz_softmax_flat(\n                *flatten_probs(\n                    prob.unsqueeze(0), label.unsqueeze(0), ignore_index),\n                classes=classes,\n                class_weight=class_weight)\n            for prob, label in zip(probs, labels)\n        ]\n        loss = weight_reduce_loss(\n            torch.stack(loss), None, reduction, avg_factor)\n    else:\n        loss = lovasz_softmax_flat(\n            *flatten_probs(probs, labels, ignore_index),\n            classes=classes,\n            class_weight=class_weight)\n    return loss\n\n\n@LOSSES.register_module()\nclass LovaszLoss(nn.Module):\n    \"\"\"LovaszLoss.\n\n    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate\n    for the optimization of the intersection-over-union measure in neural\n    networks <https://arxiv.org/abs/1705.08790>`_.\n\n    Args:\n        loss_type (str, optional): Binary or multi-class loss.\n            Default: 'multi_class'. Options are \"binary\" and \"multi_class\".\n        classes (str | list[int], optional): Classes choosed to calculate loss.\n            'all' for all classes, 'present' for classes present in labels, or\n            a list of classes to average. Default: 'present'.\n        per_image (bool, optional): If per_image is True, compute the loss per\n            image instead of per batch. Default: False.\n        reduction (str, optional): The method used to reduce the loss. Options\n            are \"none\", \"mean\" and \"sum\". This parameter only works when\n            per_image is True. Default: 'mean'.\n        class_weight (list[float], optional): The weight for each class.\n            Default: None.\n        loss_weight (float, optional): Weight of the loss. 
Defaults to 1.0.\n    \"\"\"\n\n    def __init__(self,\n                 loss_type='multi_class',\n                 classes='present',\n                 per_image=False,\n                 reduction='mean',\n                 class_weight=None,\n                 loss_weight=1.0):\n        super(LovaszLoss, self).__init__()\n        assert loss_type in ('binary', 'multi_class'), \"loss_type should be \\\n                                                    'binary' or 'multi_class'.\"\n\n        if loss_type == 'binary':\n            self.cls_criterion = lovasz_hinge\n        else:\n            self.cls_criterion = lovasz_softmax\n        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)\n        if not per_image:\n            assert reduction == 'none', \"reduction should be 'none' when \\\n                                                        per_image is False.\"\n\n        self.classes = classes\n        self.per_image = per_image\n        self.reduction = reduction\n        self.loss_weight = loss_weight\n        self.class_weight = class_weight\n\n    def forward(self,\n                cls_score,\n                label,\n                weight=None,\n                avg_factor=None,\n                reduction_override=None,\n                **kwargs):\n        \"\"\"Forward function.\"\"\"\n        assert reduction_override in (None, 'none', 'mean', 'sum')\n        reduction = (\n            reduction_override if reduction_override else self.reduction)\n        if self.class_weight is not None:\n            class_weight = cls_score.new_tensor(self.class_weight)\n        else:\n            class_weight = None\n\n        # if multi-class loss, transform logits to probs\n        if self.cls_criterion == lovasz_softmax:\n            cls_score = F.softmax(cls_score, dim=1)\n\n        loss_cls = self.loss_weight * self.cls_criterion(\n            cls_score,\n            label,\n            self.classes,\n            self.per_image,\n            class_weight=class_weight,\n            reduction=reduction,\n            avg_factor=avg_factor,\n            **kwargs)\n        return loss_cls\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/losses/utils.py",
    "content": "import functools\n\nimport torch.nn.functional as F\n\n\ndef reduce_loss(loss, reduction):\n    \"\"\"Reduce loss as specified.\n\n    Args:\n        loss (Tensor): Elementwise loss tensor.\n        reduction (str): Options are \"none\", \"mean\" and \"sum\".\n\n    Return:\n        Tensor: Reduced loss tensor.\n    \"\"\"\n    reduction_enum = F._Reduction.get_enum(reduction)\n    # none: 0, elementwise_mean:1, sum: 2\n    if reduction_enum == 0:\n        return loss\n    elif reduction_enum == 1:\n        return loss.mean()\n    elif reduction_enum == 2:\n        return loss.sum()\n\n\ndef weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):\n    \"\"\"Apply element-wise weight and reduce loss.\n\n    Args:\n        loss (Tensor): Element-wise loss.\n        weight (Tensor): Element-wise weights.\n        reduction (str): Same as built-in losses of PyTorch.\n        avg_factor (float): Avarage factor when computing the mean of losses.\n\n    Returns:\n        Tensor: Processed loss values.\n    \"\"\"\n    # if weight is specified, apply element-wise weight\n    if weight is not None:\n        assert weight.dim() == loss.dim()\n        if weight.dim() > 1:\n            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)\n        loss = loss * weight\n\n    # if avg_factor is not specified, just reduce the loss\n    if avg_factor is None:\n        loss = reduce_loss(loss, reduction)\n    else:\n        # if reduction is mean, then average the loss by avg_factor\n        if reduction == 'mean':\n            loss = loss.sum() / avg_factor\n        # if reduction is 'none', then do nothing, otherwise raise an error\n        elif reduction != 'none':\n            raise ValueError('avg_factor can not be used with reduction=\"sum\"')\n    return loss\n\n\ndef weighted_loss(loss_func):\n    \"\"\"Create a weighted version of a given loss function.\n\n    To use this decorator, the loss function must have the signature like\n    `loss_func(pred, target, **kwargs)`. The function only needs to compute\n    element-wise loss without any reduction. This decorator will add weight\n    and reduction arguments to the function. The decorated function will have\n    the signature like `loss_func(pred, target, weight=None, reduction='mean',\n    avg_factor=None, **kwargs)`.\n\n    :Example:\n\n    >>> import torch\n    >>> @weighted_loss\n    >>> def l1_loss(pred, target):\n    >>>     return (pred - target).abs()\n\n    >>> pred = torch.Tensor([0, 2, 3])\n    >>> target = torch.Tensor([1, 1, 1])\n    >>> weight = torch.Tensor([1, 0, 1])\n\n    >>> l1_loss(pred, target)\n    tensor(1.3333)\n    >>> l1_loss(pred, target, weight)\n    tensor(1.)\n    >>> l1_loss(pred, target, reduction='none')\n    tensor([1., 1., 2.])\n    >>> l1_loss(pred, target, weight, avg_factor=2)\n    tensor(1.5000)\n    \"\"\"\n\n    @functools.wraps(loss_func)\n    def wrapper(pred,\n                target,\n                weight=None,\n                reduction='mean',\n                avg_factor=None,\n                **kwargs):\n        # get element-wise loss\n        loss = loss_func(pred, target, **kwargs)\n        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)\n        return loss\n\n    return wrapper\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/necks/__init__.py",
    "content": "from .fpn import FPN\n\n__all__ = ['FPN']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/necks/fpn.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule, xavier_init\n\nfrom ..builder import NECKS\n\n\n@NECKS.register_module()\nclass FPN(nn.Module):\n    \"\"\"Feature Pyramid Network.\n\n    This is an implementation of - Feature Pyramid Networks for Object\n    Detection (https://arxiv.org/abs/1612.03144)\n\n    Args:\n        in_channels (List[int]): Number of input channels per scale.\n        out_channels (int): Number of output channels (used at each scale)\n        num_outs (int): Number of output scales.\n        start_level (int): Index of the start input backbone level used to\n            build the feature pyramid. Default: 0.\n        end_level (int): Index of the end input backbone level (exclusive) to\n            build the feature pyramid. Default: -1, which means the last level.\n        add_extra_convs (bool | str): If bool, it decides whether to add conv\n            layers on top of the original feature maps. Default to False.\n            If True, its actual mode is specified by `extra_convs_on_inputs`.\n            If str, it specifies the source feature map of the extra convs.\n            Only the following options are allowed\n\n            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).\n            - 'on_lateral':  Last feature map after lateral convs.\n            - 'on_output': The last output feature map after fpn convs.\n        extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs\n            on the original feature from the backbone. If True,\n            it is equivalent to `add_extra_convs='on_input'`. If False, it is\n            equivalent to set `add_extra_convs='on_output'`. Default to True.\n        relu_before_extra_convs (bool): Whether to apply relu before the extra\n            conv. Default: False.\n        no_norm_on_lateral (bool): Whether to apply norm on lateral.\n            Default: False.\n        conv_cfg (dict): Config dict for convolution layer. Default: None.\n        norm_cfg (dict): Config dict for normalization layer. Default: None.\n        act_cfg (str): Config dict for activation layer in ConvModule.\n            Default: None.\n        upsample_cfg (dict): Config dict for interpolate layer.\n            Default: `dict(mode='nearest')`\n\n    Example:\n        >>> import torch\n        >>> in_channels = [2, 3, 5, 7]\n        >>> scales = [340, 170, 84, 43]\n        >>> inputs = [torch.rand(1, c, s, s)\n        ...           for c, s in zip(in_channels, scales)]\n        >>> self = FPN(in_channels, 11, len(in_channels)).eval()\n        >>> outputs = self.forward(inputs)\n        >>> for i in range(len(outputs)):\n        ...     
print(f'outputs[{i}].shape = {outputs[i].shape}')\n        outputs[0].shape = torch.Size([1, 11, 340, 340])\n        outputs[1].shape = torch.Size([1, 11, 170, 170])\n        outputs[2].shape = torch.Size([1, 11, 84, 84])\n        outputs[3].shape = torch.Size([1, 11, 43, 43])\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 num_outs,\n                 start_level=0,\n                 end_level=-1,\n                 add_extra_convs=False,\n                 extra_convs_on_inputs=False,\n                 relu_before_extra_convs=False,\n                 no_norm_on_lateral=False,\n                 conv_cfg=None,\n                 norm_cfg=None,\n                 act_cfg=None,\n                 upsample_cfg=dict(mode='nearest')):\n        super(FPN, self).__init__()\n        assert isinstance(in_channels, list)\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_ins = len(in_channels)\n        self.num_outs = num_outs\n        self.relu_before_extra_convs = relu_before_extra_convs\n        self.no_norm_on_lateral = no_norm_on_lateral\n        self.fp16_enabled = False\n        self.upsample_cfg = upsample_cfg.copy()\n\n        if end_level == -1:\n            self.backbone_end_level = self.num_ins\n            assert num_outs >= self.num_ins - start_level\n        else:\n            # if end_level < inputs, no extra level is allowed\n            self.backbone_end_level = end_level\n            assert end_level <= len(in_channels)\n            assert num_outs == end_level - start_level\n        self.start_level = start_level\n        self.end_level = end_level\n        self.add_extra_convs = add_extra_convs\n        assert isinstance(add_extra_convs, (str, bool))\n        if isinstance(add_extra_convs, str):\n            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'\n            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')\n        elif add_extra_convs:  # True\n            if extra_convs_on_inputs:\n                # For compatibility with previous release\n                # TODO: deprecate `extra_convs_on_inputs`\n                self.add_extra_convs = 'on_input'\n            else:\n                self.add_extra_convs = 'on_output'\n\n        self.lateral_convs = nn.ModuleList()\n        self.fpn_convs = nn.ModuleList()\n\n        for i in range(self.start_level, self.backbone_end_level):\n            l_conv = ConvModule(\n                in_channels[i],\n                out_channels,\n                1,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,\n                act_cfg=act_cfg,\n                inplace=False)\n            fpn_conv = ConvModule(\n                out_channels,\n                out_channels,\n                3,\n                padding=1,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg,\n                inplace=False)\n\n            self.lateral_convs.append(l_conv)\n            self.fpn_convs.append(fpn_conv)\n\n        # add extra conv layers (e.g., RetinaNet)\n        extra_levels = num_outs - self.backbone_end_level + self.start_level\n        if self.add_extra_convs and extra_levels >= 1:\n            for i in range(extra_levels):\n                if i == 0 and self.add_extra_convs == 'on_input':\n                    in_channels = self.in_channels[self.backbone_end_level - 1]\n             
   else:\n                    in_channels = out_channels\n                extra_fpn_conv = ConvModule(\n                    in_channels,\n                    out_channels,\n                    3,\n                    stride=2,\n                    padding=1,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg,\n                    inplace=False)\n                self.fpn_convs.append(extra_fpn_conv)\n\n    # default init_weights for conv(msra) and norm in ConvModule\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                xavier_init(m, distribution='uniform')\n\n    def forward(self, inputs):\n        assert len(inputs) == len(self.in_channels)\n\n        # build laterals\n        laterals = [\n            lateral_conv(inputs[i + self.start_level])\n            for i, lateral_conv in enumerate(self.lateral_convs)\n        ]\n\n        # build top-down path\n        used_backbone_levels = len(laterals)\n        for i in range(used_backbone_levels - 1, 0, -1):\n            # In some cases, fixing `scale factor` (e.g. 2) is preferred, but\n            #  it cannot co-exist with `size` in `F.interpolate`.\n            if 'scale_factor' in self.upsample_cfg:\n                laterals[i - 1] += F.interpolate(laterals[i],\n                                                 **self.upsample_cfg)\n            else:\n                prev_shape = laterals[i - 1].shape[2:]\n                laterals[i - 1] += F.interpolate(\n                    laterals[i], size=prev_shape, **self.upsample_cfg)\n\n        # build outputs\n        # part 1: from original levels\n        outs = [\n            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)\n        ]\n        # part 2: add extra levels\n        if self.num_outs > len(outs):\n            # use max pool to get more levels on top of outputs\n            # (e.g., Faster R-CNN, Mask R-CNN)\n            if not self.add_extra_convs:\n                for i in range(self.num_outs - used_backbone_levels):\n                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))\n            # add conv layers on top of original feature maps (RetinaNet)\n            else:\n                if self.add_extra_convs == 'on_input':\n                    extra_source = inputs[self.backbone_end_level - 1]\n                elif self.add_extra_convs == 'on_lateral':\n                    extra_source = laterals[-1]\n                elif self.add_extra_convs == 'on_output':\n                    extra_source = outs[-1]\n                else:\n                    raise NotImplementedError\n                outs.append(self.fpn_convs[used_backbone_levels](extra_source))\n                for i in range(used_backbone_levels + 1, self.num_outs):\n                    if self.relu_before_extra_convs:\n                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))\n                    else:\n                        outs.append(self.fpn_convs[i](outs[-1]))\n        return tuple(outs)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/segmentors/__init__.py",
    "content": "from .cascade_encoder_decoder import CascadeEncoderDecoder\nfrom .encoder_decoder import EncoderDecoder\n\n__all__ = ['EncoderDecoder', 'CascadeEncoderDecoder']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/segmentors/base.py",
    "content": "import logging\nimport warnings\nfrom abc import ABCMeta, abstractmethod\nfrom collections import OrderedDict\n\nimport mmcv\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom mmcv.runner import auto_fp16\n\n\nclass BaseSegmentor(nn.Module):\n    \"\"\"Base class for segmentors.\"\"\"\n\n    __metaclass__ = ABCMeta\n\n    def __init__(self):\n        super(BaseSegmentor, self).__init__()\n        self.fp16_enabled = False\n\n    @property\n    def with_neck(self):\n        \"\"\"bool: whether the segmentor has neck\"\"\"\n        return hasattr(self, 'neck') and self.neck is not None\n\n    @property\n    def with_auxiliary_head(self):\n        \"\"\"bool: whether the segmentor has auxiliary head\"\"\"\n        return hasattr(self,\n                       'auxiliary_head') and self.auxiliary_head is not None\n\n    @property\n    def with_decode_head(self):\n        \"\"\"bool: whether the segmentor has decode head\"\"\"\n        return hasattr(self, 'decode_head') and self.decode_head is not None\n\n    @abstractmethod\n    def extract_feat(self, imgs):\n        \"\"\"Placeholder for extract features from images.\"\"\"\n        pass\n\n    @abstractmethod\n    def encode_decode(self, img, img_metas):\n        \"\"\"Placeholder for encode images with backbone and decode into a\n        semantic segmentation map of the same size as input.\"\"\"\n        pass\n\n    @abstractmethod\n    def forward_train(self, imgs, img_metas, **kwargs):\n        \"\"\"Placeholder for Forward function for training.\"\"\"\n        pass\n\n    @abstractmethod\n    def simple_test(self, img, img_meta, **kwargs):\n        \"\"\"Placeholder for single image test.\"\"\"\n        pass\n\n    @abstractmethod\n    def aug_test(self, imgs, img_metas, **kwargs):\n        \"\"\"Placeholder for augmentation test.\"\"\"\n        pass\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in segmentor.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        if pretrained is not None:\n            logger = logging.getLogger()\n            logger.info(f'load model from: {pretrained}')\n\n    def forward_test(self, imgs, img_metas, **kwargs):\n        \"\"\"\n        Args:\n            imgs (List[Tensor]): the outer list indicates test-time\n                augmentations and inner Tensor should have a shape NxCxHxW,\n                which contains all images in the batch.\n            img_metas (List[List[dict]]): the outer list indicates test-time\n                augs (multiscale, flip, etc.) 
and the inner list indicates\n                images in a batch.\n        \"\"\"\n        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:\n            if not isinstance(var, list):\n                raise TypeError(f'{name} must be a list, but got '\n                                f'{type(var)}')\n\n        num_augs = len(imgs)\n        if num_augs != len(img_metas):\n            raise ValueError(f'num of augmentations ({len(imgs)}) != '\n                             f'num of image meta ({len(img_metas)})')\n        # all images in the same aug batch should have the same ori_shape and\n        # pad shape\n        for img_meta in img_metas:\n            ori_shapes = [_['ori_shape'] for _ in img_meta]\n            assert all(shape == ori_shapes[0] for shape in ori_shapes)\n            img_shapes = [_['img_shape'] for _ in img_meta]\n            assert all(shape == img_shapes[0] for shape in img_shapes)\n            pad_shapes = [_['pad_shape'] for _ in img_meta]\n            assert all(shape == pad_shapes[0] for shape in pad_shapes)\n\n        if num_augs == 1:\n            return self.simple_test(imgs[0], img_metas[0], **kwargs)\n        else:\n            return self.aug_test(imgs, img_metas, **kwargs)\n\n    @auto_fp16(apply_to=('img', ))\n    def forward(self, img, img_metas, return_loss=True, **kwargs):\n        \"\"\"Calls either :func:`forward_train` or :func:`forward_test` depending\n        on whether ``return_loss`` is ``True``.\n\n        Note this setting will change the expected inputs. When\n        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor\n        and List[dict]), and when ``return_loss=False``, img and img_meta\n        should be double nested (i.e. List[Tensor], List[List[dict]]), with\n        the outer list indicating test time augmentations.\n        \"\"\"\n        if return_loss:\n            return self.forward_train(img, img_metas, **kwargs)\n        else:\n            return self.forward_test(img, img_metas, **kwargs)\n\n    def train_step(self, data_batch, optimizer, **kwargs):\n        \"\"\"The iteration step during training.\n\n        This method defines an iteration step during training, except for the\n        back propagation and optimizer updating, which are done in an optimizer\n        hook. Note that in some complicated cases or models, the whole process\n        including back propagation and optimizer updating is also defined in\n        this method, such as GAN.\n\n        Args:\n            data_batch (dict): The output of dataloader.\n            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of\n                runner is passed to ``train_step()``. 
This argument is unused\n                and reserved.\n\n        Returns:\n            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,\n                ``num_samples``.\n                ``loss`` is a tensor for back propagation, which can be a\n                weighted sum of multiple losses.\n                ``log_vars`` contains all the variables to be sent to the\n                logger.\n                ``num_samples`` indicates the batch size (when the model is\n                DDP, it means the batch size on each GPU), which is used for\n                averaging the logs.\n        \"\"\"\n        losses = self(**data_batch)\n        loss, log_vars = self._parse_losses(losses)\n\n        outputs = dict(\n            loss=loss,\n            log_vars=log_vars,\n            num_samples=len(data_batch['img'].data))\n\n        return outputs\n\n    def val_step(self, data_batch, **kwargs):\n        \"\"\"The iteration step during validation.\n\n        This method shares the same signature as :func:`train_step`, but used\n        during val epochs. Note that the evaluation after training epochs is\n        not implemented with this method, but an evaluation hook.\n        \"\"\"\n        output = self(**data_batch, **kwargs)\n        return output\n\n    @staticmethod\n    def _parse_losses(losses):\n        \"\"\"Parse the raw outputs (losses) of the network.\n\n        Args:\n            losses (dict): Raw output of the network, which usually contain\n                losses and other necessary information.\n\n        Returns:\n            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor\n                which may be a weighted sum of all losses, log_vars contains\n                all the variables to be sent to the logger.\n        \"\"\"\n        log_vars = OrderedDict()\n        for loss_name, loss_value in losses.items():\n            if isinstance(loss_value, torch.Tensor):\n                log_vars[loss_name] = loss_value.mean()\n            elif isinstance(loss_value, list):\n                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)\n            else:\n                raise TypeError(\n                    f'{loss_name} is not a tensor or list of tensors')\n\n        loss = sum(_value for _key, _value in log_vars.items()\n                   if 'loss' in _key)\n\n        log_vars['loss'] = loss\n        for loss_name, loss_value in log_vars.items():\n            # reduce loss when distributed training\n            if dist.is_available() and dist.is_initialized():\n                loss_value = loss_value.data.clone()\n                dist.all_reduce(loss_value.div_(dist.get_world_size()))\n            log_vars[loss_name] = loss_value.item()\n\n        return loss, log_vars\n\n    def show_result(self,\n                    img,\n                    result,\n                    palette=None,\n                    win_name='',\n                    show=False,\n                    wait_time=0,\n                    out_file=None):\n        \"\"\"Draw `result` over `img`.\n\n        Args:\n            img (str or Tensor): The image to be displayed.\n            result (Tensor): The semantic segmentation results to draw over\n                `img`.\n            palette (list[list[int]]] | np.ndarray | None): The palette of\n                segmentation map. If None is given, random palette will be\n                generated. 
Default: None\n            win_name (str): The window name.\n            wait_time (int): Value of waitKey param.\n                Default: 0.\n            show (bool): Whether to show the image.\n                Default: False.\n            out_file (str or None): The filename to write the image.\n                Default: None.\n\n        Returns:\n            img (Tensor): Only if not `show` or `out_file`\n        \"\"\"\n        img = mmcv.imread(img)\n        img = img.copy()\n        seg = result[0]\n        if palette is None:\n            if self.PALETTE is None:\n                palette = np.random.randint(\n                    0, 255, size=(len(self.CLASSES), 3))\n            else:\n                palette = self.PALETTE\n        palette = np.array(palette)\n        assert palette.shape[0] == len(self.CLASSES)\n        assert palette.shape[1] == 3\n        assert len(palette.shape) == 2\n        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)\n        for label, color in enumerate(palette):\n            color_seg[seg == label, :] = color\n        # convert to BGR\n        color_seg = color_seg[..., ::-1]\n\n        img = img * 0.5 + color_seg * 0.5\n        img = img.astype(np.uint8)\n        # if out_file specified, do not show image in window\n        if out_file is not None:\n            show = False\n\n        if show:\n            mmcv.imshow(img, win_name, wait_time)\n        if out_file is not None:\n            mmcv.imwrite(img, out_file)\n\n        if not (show or out_file):\n            warnings.warn('show==False and out_file is not specified, only '\n                          'result image will be returned')\n            return img\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/segmentors/cascade_encoder_decoder.py",
    "content": "from torch import nn\n\nfrom mmseg.core import add_prefix\nfrom mmseg.ops import resize\nfrom .. import builder\nfrom ..builder import SEGMENTORS\nfrom .encoder_decoder import EncoderDecoder\n\n\n@SEGMENTORS.register_module()\nclass CascadeEncoderDecoder(EncoderDecoder):\n    \"\"\"Cascade Encoder Decoder segmentors.\n\n    CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of\n    CascadeEncoderDecoder are cascaded. The output of previous decoder_head\n    will be the input of next decoder_head.\n    \"\"\"\n\n    def __init__(self,\n                 num_stages,\n                 backbone,\n                 decode_head,\n                 neck=None,\n                 auxiliary_head=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 pretrained=None):\n        self.num_stages = num_stages\n        super(CascadeEncoderDecoder, self).__init__(\n            backbone=backbone,\n            decode_head=decode_head,\n            neck=neck,\n            auxiliary_head=auxiliary_head,\n            train_cfg=train_cfg,\n            test_cfg=test_cfg,\n            pretrained=pretrained)\n\n    def _init_decode_head(self, decode_head):\n        \"\"\"Initialize ``decode_head``\"\"\"\n        assert isinstance(decode_head, list)\n        assert len(decode_head) == self.num_stages\n        self.decode_head = nn.ModuleList()\n        for i in range(self.num_stages):\n            self.decode_head.append(builder.build_head(decode_head[i]))\n        self.align_corners = self.decode_head[-1].align_corners\n        self.num_classes = self.decode_head[-1].num_classes\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone and heads.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        self.backbone.init_weights(pretrained=pretrained)\n        for i in range(self.num_stages):\n            self.decode_head[i].init_weights()\n        if self.with_auxiliary_head:\n            if isinstance(self.auxiliary_head, nn.ModuleList):\n                for aux_head in self.auxiliary_head:\n                    aux_head.init_weights()\n            else:\n                self.auxiliary_head.init_weights()\n\n    def encode_decode(self, img, img_metas):\n        \"\"\"Encode images with backbone and decode into a semantic segmentation\n        map of the same size as input.\"\"\"\n        x = self.extract_feat(img)\n        out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg)\n        for i in range(1, self.num_stages):\n            out = self.decode_head[i].forward_test(x, out, img_metas,\n                                                   self.test_cfg)\n        out = resize(\n            input=out,\n            size=img.shape[2:],\n            mode='bilinear',\n            align_corners=self.align_corners)\n        return out\n\n    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):\n        \"\"\"Run forward function and calculate loss for decode head in\n        training.\"\"\"\n        losses = dict()\n\n        loss_decode = self.decode_head[0].forward_train(\n            x, img_metas, gt_semantic_seg, self.train_cfg)\n\n        losses.update(add_prefix(loss_decode, 'decode_0'))\n\n        for i in range(1, self.num_stages):\n            # forward test again, maybe unnecessary for most methods.\n            prev_outputs = self.decode_head[i - 1].forward_test(\n                x, 
img_metas, self.test_cfg)\n            loss_decode = self.decode_head[i].forward_train(\n                x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)\n            losses.update(add_prefix(loss_decode, f'decode_{i}'))\n\n        return losses\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/segmentors/encoder_decoder.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom mmseg.core import add_prefix\nfrom mmseg.ops import resize\nfrom .. import builder\nfrom ..builder import SEGMENTORS\nfrom .base import BaseSegmentor\n\n\n@SEGMENTORS.register_module()\nclass EncoderDecoder(BaseSegmentor):\n    \"\"\"Encoder Decoder segmentors.\n\n    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.\n    Note that auxiliary_head is only used for deep supervision during training,\n    which could be dumped during inference.\n    \"\"\"\n\n    def __init__(self,\n                 backbone,\n                 decode_head,\n                 neck=None,\n                 auxiliary_head=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 pretrained=None):\n        super(EncoderDecoder, self).__init__()\n        self.backbone = builder.build_backbone(backbone)\n        if neck is not None:\n            self.neck = builder.build_neck(neck)\n        self._init_decode_head(decode_head)\n        self._init_auxiliary_head(auxiliary_head)\n\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n\n        self.init_weights(pretrained=pretrained)\n\n        assert self.with_decode_head\n\n    def _init_decode_head(self, decode_head):\n        \"\"\"Initialize ``decode_head``\"\"\"\n        self.decode_head = builder.build_head(decode_head)\n        self.align_corners = self.decode_head.align_corners\n        self.num_classes = self.decode_head.num_classes\n\n    def _init_auxiliary_head(self, auxiliary_head):\n        \"\"\"Initialize ``auxiliary_head``\"\"\"\n        if auxiliary_head is not None:\n            if isinstance(auxiliary_head, list):\n                self.auxiliary_head = nn.ModuleList()\n                for head_cfg in auxiliary_head:\n                    self.auxiliary_head.append(builder.build_head(head_cfg))\n            else:\n                self.auxiliary_head = builder.build_head(auxiliary_head)\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone and heads.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        super(EncoderDecoder, self).init_weights(pretrained)\n        self.backbone.init_weights(pretrained=pretrained)\n        self.decode_head.init_weights()\n        if self.with_auxiliary_head:\n            if isinstance(self.auxiliary_head, nn.ModuleList):\n                for aux_head in self.auxiliary_head:\n                    aux_head.init_weights()\n            else:\n                self.auxiliary_head.init_weights()\n\n    def extract_feat(self, img):\n        \"\"\"Extract features from images.\"\"\"\n        x = self.backbone(img)\n        if self.with_neck:\n            x = self.neck(x)\n        return x\n\n    def encode_decode(self, img, img_metas):\n        \"\"\"Encode images with backbone and decode into a semantic segmentation\n        map of the same size as input.\"\"\"\n        x = self.extract_feat(img)\n        out = self._decode_head_forward_test(x, img_metas)\n        out = resize(\n            input=out,\n            size=img.shape[2:],\n            mode='bilinear',\n            align_corners=self.align_corners)\n        return out\n\n    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):\n        \"\"\"Run forward function and calculate loss for decode head in\n        training.\"\"\"\n        losses = dict()\n  
      loss_decode = self.decode_head.forward_train(x, img_metas,\n                                                     gt_semantic_seg,\n                                                     self.train_cfg)\n\n        losses.update(add_prefix(loss_decode, 'decode'))\n        return losses\n\n    def _decode_head_forward_test(self, x, img_metas):\n        \"\"\"Run forward function and calculate loss for decode head in\n        inference.\"\"\"\n        seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)\n        return seg_logits\n\n    def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):\n        \"\"\"Run forward function and calculate loss for auxiliary head in\n        training.\"\"\"\n        losses = dict()\n        if isinstance(self.auxiliary_head, nn.ModuleList):\n            for idx, aux_head in enumerate(self.auxiliary_head):\n                loss_aux = aux_head.forward_train(x, img_metas,\n                                                  gt_semantic_seg,\n                                                  self.train_cfg)\n                losses.update(add_prefix(loss_aux, f'aux_{idx}'))\n        else:\n            loss_aux = self.auxiliary_head.forward_train(\n                x, img_metas, gt_semantic_seg, self.train_cfg)\n            losses.update(add_prefix(loss_aux, 'aux'))\n\n        return losses\n\n    def forward_dummy(self, img):\n        \"\"\"Dummy forward function.\"\"\"\n        seg_logit = self.encode_decode(img, None)\n\n        return seg_logit\n\n    def forward_train(self, img, img_metas, gt_semantic_seg):\n        \"\"\"Forward function for training.\n\n        Args:\n            img (Tensor): Input images.\n            img_metas (list[dict]): List of image info dict where each dict\n                has: 'img_shape', 'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmseg/datasets/pipelines/formatting.py:Collect`.\n            gt_semantic_seg (Tensor): Semantic segmentation masks\n                used if the architecture supports semantic segmentation task.\n\n        Returns:\n            dict[str, Tensor]: a dictionary of loss components\n        \"\"\"\n\n        x = self.extract_feat(img)\n\n        losses = dict()\n\n        loss_decode = self._decode_head_forward_train(x, img_metas,\n                                                      gt_semantic_seg)\n        losses.update(loss_decode)\n\n        if self.with_auxiliary_head:\n            loss_aux = self._auxiliary_head_forward_train(\n                x, img_metas, gt_semantic_seg)\n            losses.update(loss_aux)\n\n        return losses\n\n    # TODO refactor\n    def slide_inference(self, img, img_meta, rescale):\n        \"\"\"Inference by sliding-window with overlap.\n\n        If h_crop > h_img or w_crop > w_img, the small patch will be used to\n        decode without padding.\n        \"\"\"\n\n        h_stride, w_stride = self.test_cfg.stride\n        h_crop, w_crop = self.test_cfg.crop_size\n        batch_size, _, h_img, w_img = img.size()\n        num_classes = self.num_classes\n        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1\n        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1\n        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))\n        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))\n        for h_idx in range(h_grids):\n           
 for w_idx in range(w_grids):\n                y1 = h_idx * h_stride\n                x1 = w_idx * w_stride\n                y2 = min(y1 + h_crop, h_img)\n                x2 = min(x1 + w_crop, w_img)\n                y1 = max(y2 - h_crop, 0)\n                x1 = max(x2 - w_crop, 0)\n                crop_img = img[:, :, y1:y2, x1:x2]\n                crop_seg_logit = self.encode_decode(crop_img, img_meta)\n                preds += F.pad(crop_seg_logit,\n                               (int(x1), int(preds.shape[3] - x2), int(y1),\n                                int(preds.shape[2] - y2)))\n\n                count_mat[:, :, y1:y2, x1:x2] += 1\n        assert (count_mat == 0).sum() == 0\n        if torch.onnx.is_in_onnx_export():\n            # cast count_mat to constant while exporting to ONNX\n            count_mat = torch.from_numpy(\n                count_mat.cpu().detach().numpy()).to(device=img.device)\n        preds = preds / count_mat\n        if rescale:\n            preds = resize(\n                preds,\n                size=img_meta[0]['ori_shape'][:2],\n                mode='bilinear',\n                align_corners=self.align_corners,\n                warning=False)\n        return preds\n\n    def whole_inference(self, img, img_meta, rescale):\n        \"\"\"Inference with full image.\"\"\"\n\n        seg_logit = self.encode_decode(img, img_meta)\n        if rescale:\n            seg_logit = resize(\n                seg_logit,\n                size=img_meta[0]['ori_shape'][:2],\n                mode='bilinear',\n                align_corners=self.align_corners,\n                warning=False)\n\n        return seg_logit\n\n    def inference(self, img, img_meta, rescale):\n        \"\"\"Inference with slide/whole style.\n\n        Args:\n            img (Tensor): The input image of shape (N, 3, H, W).\n            img_meta (dict): Image info dict where each dict has: 'img_shape',\n                'scale_factor', 'flip', and may also contain\n                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.\n                For details on the values of these keys see\n                `mmseg/datasets/pipelines/formatting.py:Collect`.\n            rescale (bool): Whether rescale back to original shape.\n\n        Returns:\n            Tensor: The output segmentation map.\n        \"\"\"\n\n        assert self.test_cfg.mode in ['slide', 'whole']\n        ori_shape = img_meta[0]['ori_shape']\n        assert all(_['ori_shape'] == ori_shape for _ in img_meta)\n        if self.test_cfg.mode == 'slide':\n            seg_logit = self.slide_inference(img, img_meta, rescale)\n        else:\n            seg_logit = self.whole_inference(img, img_meta, rescale)\n        output = F.softmax(seg_logit, dim=1)\n        flip = img_meta[0]['flip']\n        if flip:\n            flip_direction = img_meta[0]['flip_direction']\n            assert flip_direction in ['horizontal', 'vertical']\n            if flip_direction == 'horizontal':\n                output = output.flip(dims=(3, ))\n            elif flip_direction == 'vertical':\n                output = output.flip(dims=(2, ))\n\n        return output\n\n    def simple_test(self, img, img_meta, rescale=True):\n        \"\"\"Simple test with single image.\"\"\"\n        seg_logit = self.inference(img, img_meta, rescale)\n        seg_pred = seg_logit.argmax(dim=1)\n        if torch.onnx.is_in_onnx_export():\n            # our inference backend only support 4D output\n            seg_pred = seg_pred.unsqueeze(0)\n            return 
seg_pred\n        seg_pred = seg_pred.cpu().numpy()\n        # unravel batch dim\n        seg_pred = list(seg_pred)\n        return seg_pred\n\n    def aug_test(self, imgs, img_metas, rescale=True):\n        \"\"\"Test with augmentations.\n\n        Only rescale=True is supported.\n        \"\"\"\n        # aug_test rescale all imgs back to ori_shape for now\n        assert rescale\n        # to save memory, we get augmented seg logit inplace\n        seg_logit = self.inference(imgs[0], img_metas[0], rescale)\n        for i in range(1, len(imgs)):\n            cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)\n            seg_logit += cur_seg_logit\n        seg_logit /= len(imgs)\n        seg_pred = seg_logit.argmax(dim=1)\n        seg_pred = seg_pred.cpu().numpy()\n        # unravel batch dim\n        seg_pred = list(seg_pred)\n        return seg_pred\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/utils/__init__.py",
    "content": "from .inverted_residual import InvertedResidual, InvertedResidualV3\nfrom .make_divisible import make_divisible\nfrom .res_layer import ResLayer\nfrom .self_attention_block import SelfAttentionBlock\nfrom .up_conv_block import UpConvBlock\n\n__all__ = [\n    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',\n    'UpConvBlock', 'InvertedResidualV3'\n]\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/utils/inverted_residual.py",
    "content": "from mmcv.cnn import ConvModule\nfrom torch import nn as nn\nfrom torch.utils import checkpoint as cp\n\nfrom .se_layer import SELayer\n\n\nclass InvertedResidual(nn.Module):\n    \"\"\"InvertedResidual block for MobileNetV2.\n\n    Args:\n        in_channels (int): The input channels of the InvertedResidual block.\n        out_channels (int): The output channels of the InvertedResidual block.\n        stride (int): Stride of the middle (first) 3x3 convolution.\n        expand_ratio (int): Adjusts number of channels of the hidden layer\n            in InvertedResidual by this amount.\n        dilation (int): Dilation rate of depthwise conv. Default: 1\n        conv_cfg (dict): Config dict for convolution layer.\n            Default: None, which means using conv2d.\n        norm_cfg (dict): Config dict for normalization layer.\n            Default: dict(type='BN').\n        act_cfg (dict): Config dict for activation layer.\n            Default: dict(type='ReLU6').\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n\n    Returns:\n        Tensor: The output tensor.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 stride,\n                 expand_ratio,\n                 dilation=1,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU6'),\n                 with_cp=False):\n        super(InvertedResidual, self).__init__()\n        self.stride = stride\n        assert stride in [1, 2], f'stride must in [1, 2]. ' \\\n            f'But received {stride}.'\n        self.with_cp = with_cp\n        self.use_res_connect = self.stride == 1 and in_channels == out_channels\n        hidden_dim = int(round(in_channels * expand_ratio))\n\n        layers = []\n        if expand_ratio != 1:\n            layers.append(\n                ConvModule(\n                    in_channels=in_channels,\n                    out_channels=hidden_dim,\n                    kernel_size=1,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg))\n        layers.extend([\n            ConvModule(\n                in_channels=hidden_dim,\n                out_channels=hidden_dim,\n                kernel_size=3,\n                stride=stride,\n                padding=dilation,\n                dilation=dilation,\n                groups=hidden_dim,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg),\n            ConvModule(\n                in_channels=hidden_dim,\n                out_channels=out_channels,\n                kernel_size=1,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                act_cfg=None)\n        ])\n        self.conv = nn.Sequential(*layers)\n\n    def forward(self, x):\n\n        def _inner_forward(x):\n            if self.use_res_connect:\n                return x + self.conv(x)\n            else:\n                return self.conv(x)\n\n        if self.with_cp and x.requires_grad:\n            out = cp.checkpoint(_inner_forward, x)\n        else:\n            out = _inner_forward(x)\n\n        return out\n\n\nclass InvertedResidualV3(nn.Module):\n    \"\"\"Inverted Residual Block for MobileNetV3.\n\n    Args:\n        in_channels (int): The input channels of this Module.\n        out_channels (int): The 
output channels of this Module.\n        mid_channels (int): The input channels of the depthwise convolution.\n        kernel_size (int): The kernel size of the depthwise convolution.\n            Default: 3.\n        stride (int): The stride of the depthwise convolution. Default: 1.\n        se_cfg (dict): Config dict for se layer. Default: None, which means no\n            se layer.\n        with_expand_conv (bool): Use expand conv or not. If set False,\n            mid_channels must be the same as in_channels. Default: True.\n        conv_cfg (dict): Config dict for convolution layer. Default: None,\n            which means using conv2d.\n        norm_cfg (dict): Config dict for normalization layer.\n            Default: dict(type='BN').\n        act_cfg (dict): Config dict for activation layer.\n            Default: dict(type='ReLU').\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n\n    Returns:\n        Tensor: The output tensor.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 mid_channels,\n                 kernel_size=3,\n                 stride=1,\n                 se_cfg=None,\n                 with_expand_conv=True,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 with_cp=False):\n        super(InvertedResidualV3, self).__init__()\n        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)\n        assert stride in [1, 2]\n        self.with_cp = with_cp\n        self.with_se = se_cfg is not None\n        self.with_expand_conv = with_expand_conv\n\n        if self.with_se:\n            assert isinstance(se_cfg, dict)\n        if not self.with_expand_conv:\n            assert mid_channels == in_channels\n\n        if self.with_expand_conv:\n            self.expand_conv = ConvModule(\n                in_channels=in_channels,\n                out_channels=mid_channels,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg)\n        self.depthwise_conv = ConvModule(\n            in_channels=mid_channels,\n            out_channels=mid_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=kernel_size // 2,\n            groups=mid_channels,\n            conv_cfg=dict(\n                type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n\n        if self.with_se:\n            self.se = SELayer(**se_cfg)\n\n        self.linear_conv = ConvModule(\n            in_channels=mid_channels,\n            out_channels=out_channels,\n            kernel_size=1,\n            stride=1,\n            padding=0,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=None)\n\n    def forward(self, x):\n\n        def _inner_forward(x):\n            out = x\n\n            if self.with_expand_conv:\n                out = self.expand_conv(out)\n\n            out = self.depthwise_conv(out)\n\n            if self.with_se:\n                out = self.se(out)\n\n            out = self.linear_conv(out)\n\n            if self.with_res_shortcut:\n                return x + out\n            else:\n                return out\n\n        if self.with_cp and 
x.requires_grad:\n            out = cp.checkpoint(_inner_forward, x)\n        else:\n            out = _inner_forward(x)\n\n        return out\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/utils/make_divisible.py",
    "content": "def make_divisible(value, divisor, min_value=None, min_ratio=0.9):\n    \"\"\"Make divisible function.\n\n    This function rounds the channel number to the nearest value that can be\n    divisible by the divisor. It is taken from the original tf repo. It ensures\n    that all layers have a channel number that is divisible by divisor. It can\n    be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa\n\n    Args:\n        value (int): The original channel number.\n        divisor (int): The divisor to fully divide the channel number.\n        min_value (int): The minimum value of the output channel.\n            Default: None, means that the minimum value equal to the divisor.\n        min_ratio (float): The minimum ratio of the rounded channel number to\n            the original channel number. Default: 0.9.\n\n    Returns:\n        int: The modified output channel number.\n    \"\"\"\n\n    if min_value is None:\n        min_value = divisor\n    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)\n    # Make sure that round down does not go down by more than (1-min_ratio).\n    if new_value < min_ratio * value:\n        new_value += divisor\n    return new_value\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/utils/res_layer.py",
    "content": "from mmcv.cnn import build_conv_layer, build_norm_layer\nfrom torch import nn as nn\n\n\nclass ResLayer(nn.Sequential):\n    \"\"\"ResLayer to build ResNet style backbone.\n\n    Args:\n        block (nn.Module): block used to build ResLayer.\n        inplanes (int): inplanes of block.\n        planes (int): planes of block.\n        num_blocks (int): number of blocks.\n        stride (int): stride of the first block. Default: 1\n        avg_down (bool): Use AvgPool instead of stride conv when\n            downsampling in the bottleneck. Default: False\n        conv_cfg (dict): dictionary to construct and config conv layer.\n            Default: None\n        norm_cfg (dict): dictionary to construct and config norm layer.\n            Default: dict(type='BN')\n        multi_grid (int | None): Multi grid dilation rates of last\n            stage. Default: None\n        contract_dilation (bool): Whether contract first dilation of each layer\n            Default: False\n    \"\"\"\n\n    def __init__(self,\n                 block,\n                 inplanes,\n                 planes,\n                 num_blocks,\n                 stride=1,\n                 dilation=1,\n                 avg_down=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 multi_grid=None,\n                 contract_dilation=False,\n                 **kwargs):\n        self.block = block\n\n        downsample = None\n        if stride != 1 or inplanes != planes * block.expansion:\n            downsample = []\n            conv_stride = stride\n            if avg_down:\n                conv_stride = 1\n                downsample.append(\n                    nn.AvgPool2d(\n                        kernel_size=stride,\n                        stride=stride,\n                        ceil_mode=True,\n                        count_include_pad=False))\n            downsample.extend([\n                build_conv_layer(\n                    conv_cfg,\n                    inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=conv_stride,\n                    bias=False),\n                build_norm_layer(norm_cfg, planes * block.expansion)[1]\n            ])\n            downsample = nn.Sequential(*downsample)\n\n        layers = []\n        if multi_grid is None:\n            if dilation > 1 and contract_dilation:\n                first_dilation = dilation // 2\n            else:\n                first_dilation = dilation\n        else:\n            first_dilation = multi_grid[0]\n        layers.append(\n            block(\n                inplanes=inplanes,\n                planes=planes,\n                stride=stride,\n                dilation=first_dilation,\n                downsample=downsample,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                **kwargs))\n        inplanes = planes * block.expansion\n        for i in range(1, num_blocks):\n            layers.append(\n                block(\n                    inplanes=inplanes,\n                    planes=planes,\n                    stride=1,\n                    dilation=dilation if multi_grid is None else multi_grid[i],\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    **kwargs))\n        super(ResLayer, self).__init__(*layers)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/utils/se_layer.py",
    "content": "import mmcv\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule\n\nfrom .make_divisible import make_divisible\n\n\nclass SELayer(nn.Module):\n    \"\"\"Squeeze-and-Excitation Module.\n\n    Args:\n        channels (int): The input (and output) channels of the SE layer.\n        ratio (int): Squeeze ratio in SELayer, the intermediate channel will be\n            ``int(channels/ratio)``. Default: 16.\n        conv_cfg (None or dict): Config dict for convolution layer.\n            Default: None, which means using conv2d.\n        act_cfg (dict or Sequence[dict]): Config dict for activation layer.\n            If act_cfg is a dict, two activation layers will be configurated\n            by this dict. If act_cfg is a sequence of dicts, the first\n            activation layer will be configurated by the first dict and the\n            second activation layer will be configurated by the second dict.\n            Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,\n            divisor=6.0)).\n    \"\"\"\n\n    def __init__(self,\n                 channels,\n                 ratio=16,\n                 conv_cfg=None,\n                 act_cfg=(dict(type='ReLU'),\n                          dict(type='HSigmoid', bias=3.0, divisor=6.0))):\n        super(SELayer, self).__init__()\n        if isinstance(act_cfg, dict):\n            act_cfg = (act_cfg, act_cfg)\n        assert len(act_cfg) == 2\n        assert mmcv.is_tuple_of(act_cfg, dict)\n        self.global_avgpool = nn.AdaptiveAvgPool2d(1)\n        self.conv1 = ConvModule(\n            in_channels=channels,\n            out_channels=make_divisible(channels // ratio, 8),\n            kernel_size=1,\n            stride=1,\n            conv_cfg=conv_cfg,\n            act_cfg=act_cfg[0])\n        self.conv2 = ConvModule(\n            in_channels=make_divisible(channels // ratio, 8),\n            out_channels=channels,\n            kernel_size=1,\n            stride=1,\n            conv_cfg=conv_cfg,\n            act_cfg=act_cfg[1])\n\n    def forward(self, x):\n        out = self.global_avgpool(x)\n        out = self.conv1(out)\n        out = self.conv2(out)\n        return x * out\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/utils/self_attention_block.py",
    "content": "import torch\nfrom mmcv.cnn import ConvModule, constant_init\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\nclass SelfAttentionBlock(nn.Module):\n    \"\"\"General self-attention block/non-local block.\n\n    Please refer to https://arxiv.org/abs/1706.03762 for details about key,\n    query and value.\n\n    Args:\n        key_in_channels (int): Input channels of key feature.\n        query_in_channels (int): Input channels of query feature.\n        channels (int): Output channels of key/query transform.\n        out_channels (int): Output channels.\n        share_key_query (bool): Whether share projection weight between key\n            and query projection.\n        query_downsample (nn.Module): Query downsample module.\n        key_downsample (nn.Module): Key downsample module.\n        key_query_num_convs (int): Number of convs for key/query projection.\n        value_num_convs (int): Number of convs for value projection.\n        matmul_norm (bool): Whether normalize attention map with sqrt of\n            channels\n        with_out (bool): Whether use out projection.\n        conv_cfg (dict|None): Config of conv layers.\n        norm_cfg (dict|None): Config of norm layers.\n        act_cfg (dict|None): Config of activation layers.\n    \"\"\"\n\n    def __init__(self, key_in_channels, query_in_channels, channels,\n                 out_channels, share_key_query, query_downsample,\n                 key_downsample, key_query_num_convs, value_out_num_convs,\n                 key_query_norm, value_out_norm, matmul_norm, with_out,\n                 conv_cfg, norm_cfg, act_cfg):\n        super(SelfAttentionBlock, self).__init__()\n        if share_key_query:\n            assert key_in_channels == query_in_channels\n        self.key_in_channels = key_in_channels\n        self.query_in_channels = query_in_channels\n        self.out_channels = out_channels\n        self.channels = channels\n        self.share_key_query = share_key_query\n        self.conv_cfg = conv_cfg\n        self.norm_cfg = norm_cfg\n        self.act_cfg = act_cfg\n        self.key_project = self.build_project(\n            key_in_channels,\n            channels,\n            num_convs=key_query_num_convs,\n            use_conv_module=key_query_norm,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n        if share_key_query:\n            self.query_project = self.key_project\n        else:\n            self.query_project = self.build_project(\n                query_in_channels,\n                channels,\n                num_convs=key_query_num_convs,\n                use_conv_module=key_query_norm,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg)\n        self.value_project = self.build_project(\n            key_in_channels,\n            channels if with_out else out_channels,\n            num_convs=value_out_num_convs,\n            use_conv_module=value_out_norm,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg)\n        if with_out:\n            self.out_project = self.build_project(\n                channels,\n                out_channels,\n                num_convs=value_out_num_convs,\n                use_conv_module=value_out_norm,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg)\n        else:\n            self.out_project = None\n\n        self.query_downsample = 
query_downsample\n        self.key_downsample = key_downsample\n        self.matmul_norm = matmul_norm\n\n        self.init_weights()\n\n    def init_weights(self):\n        \"\"\"Initialize weight of later layer.\"\"\"\n        if self.out_project is not None:\n            if not isinstance(self.out_project, ConvModule):\n                constant_init(self.out_project, 0)\n\n    def build_project(self, in_channels, channels, num_convs, use_conv_module,\n                      conv_cfg, norm_cfg, act_cfg):\n        \"\"\"Build projection layer for key/query/value/out.\"\"\"\n        if use_conv_module:\n            convs = [\n                ConvModule(\n                    in_channels,\n                    channels,\n                    1,\n                    conv_cfg=conv_cfg,\n                    norm_cfg=norm_cfg,\n                    act_cfg=act_cfg)\n            ]\n            for _ in range(num_convs - 1):\n                convs.append(\n                    ConvModule(\n                        channels,\n                        channels,\n                        1,\n                        conv_cfg=conv_cfg,\n                        norm_cfg=norm_cfg,\n                        act_cfg=act_cfg))\n        else:\n            convs = [nn.Conv2d(in_channels, channels, 1)]\n            for _ in range(num_convs - 1):\n                convs.append(nn.Conv2d(channels, channels, 1))\n        if len(convs) > 1:\n            convs = nn.Sequential(*convs)\n        else:\n            convs = convs[0]\n        return convs\n\n    def forward(self, query_feats, key_feats):\n        \"\"\"Forward function.\"\"\"\n        batch_size = query_feats.size(0)\n        query = self.query_project(query_feats)\n        if self.query_downsample is not None:\n            query = self.query_downsample(query)\n        query = query.reshape(*query.shape[:2], -1)\n        query = query.permute(0, 2, 1).contiguous()\n\n        key = self.key_project(key_feats)\n        value = self.value_project(key_feats)\n        if self.key_downsample is not None:\n            key = self.key_downsample(key)\n            value = self.key_downsample(value)\n        key = key.reshape(*key.shape[:2], -1)\n        value = value.reshape(*value.shape[:2], -1)\n        value = value.permute(0, 2, 1).contiguous()\n\n        sim_map = torch.matmul(query, key)\n        if self.matmul_norm:\n            sim_map = (self.channels**-.5) * sim_map\n        sim_map = F.softmax(sim_map, dim=-1)\n\n        context = torch.matmul(sim_map, value)\n        context = context.permute(0, 2, 1).contiguous()\n        context = context.reshape(batch_size, -1, *query_feats.shape[2:])\n        if self.out_project is not None:\n            context = self.out_project(context)\n        return context\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/models/utils/up_conv_block.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.cnn import ConvModule, build_upsample_layer\n\n\nclass UpConvBlock(nn.Module):\n    \"\"\"Upsample convolution block in decoder for UNet.\n\n    This upsample convolution block consists of one upsample module\n    followed by one convolution block. The upsample module expands the\n    high-level low-resolution feature map and the convolution block fuses\n    the upsampled high-level low-resolution feature map and the low-level\n    high-resolution feature map from encoder.\n\n    Args:\n        conv_block (nn.Sequential): Sequential of convolutional layers.\n        in_channels (int): Number of input channels of the high-level\n        skip_channels (int): Number of input channels of the low-level\n        high-resolution feature map from encoder.\n        out_channels (int): Number of output channels.\n        num_convs (int): Number of convolutional layers in the conv_block.\n            Default: 2.\n        stride (int): Stride of convolutional layer in conv_block. Default: 1.\n        dilation (int): Dilation rate of convolutional layer in conv_block.\n            Default: 1.\n        with_cp (bool): Use checkpoint or not. Using checkpoint will save some\n            memory while slowing down the training speed. Default: False.\n        conv_cfg (dict | None): Config dict for convolution layer.\n            Default: None.\n        norm_cfg (dict | None): Config dict for normalization layer.\n            Default: dict(type='BN').\n        act_cfg (dict | None): Config dict for activation layer in ConvModule.\n            Default: dict(type='ReLU').\n        upsample_cfg (dict): The upsample config of the upsample module in\n            decoder. Default: dict(type='InterpConv'). If the size of\n            high-level feature map is the same as that of skip feature map\n            (low-level feature map from encoder), it does not need upsample the\n            high-level feature map and the upsample_cfg is None.\n        dcn (bool): Use deformable convoluton in convolutional layer or not.\n            Default: None.\n        plugins (dict): plugins for convolutional layers. 
Default: None.\n    \"\"\"\n\n    def __init__(self,\n                 conv_block,\n                 in_channels,\n                 skip_channels,\n                 out_channels,\n                 num_convs=2,\n                 stride=1,\n                 dilation=1,\n                 with_cp=False,\n                 conv_cfg=None,\n                 norm_cfg=dict(type='BN'),\n                 act_cfg=dict(type='ReLU'),\n                 upsample_cfg=dict(type='InterpConv'),\n                 dcn=None,\n                 plugins=None):\n        super(UpConvBlock, self).__init__()\n        assert dcn is None, 'Not implemented yet.'\n        assert plugins is None, 'Not implemented yet.'\n\n        self.conv_block = conv_block(\n            in_channels=2 * skip_channels,\n            out_channels=out_channels,\n            num_convs=num_convs,\n            stride=stride,\n            dilation=dilation,\n            with_cp=with_cp,\n            conv_cfg=conv_cfg,\n            norm_cfg=norm_cfg,\n            act_cfg=act_cfg,\n            dcn=None,\n            plugins=None)\n        if upsample_cfg is not None:\n            self.upsample = build_upsample_layer(\n                cfg=upsample_cfg,\n                in_channels=in_channels,\n                out_channels=skip_channels,\n                with_cp=with_cp,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg)\n        else:\n            self.upsample = ConvModule(\n                in_channels,\n                skip_channels,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                conv_cfg=conv_cfg,\n                norm_cfg=norm_cfg,\n                act_cfg=act_cfg)\n\n    def forward(self, skip, x):\n        \"\"\"Forward function.\"\"\"\n\n        x = self.upsample(x)\n        out = torch.cat([skip, x], dim=1)\n        out = self.conv_block(out)\n\n        return out\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/ops/__init__.py",
    "content": "from .encoding import Encoding\nfrom .wrappers import Upsample, resize\n\n__all__ = ['Upsample', 'resize', 'Encoding']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/ops/encoding.py",
    "content": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\nclass Encoding(nn.Module):\n    \"\"\"Encoding Layer: a learnable residual encoder.\n\n    Input is of shape  (batch_size, channels, height, width).\n    Output is of shape (batch_size, num_codes, channels).\n\n    Args:\n        channels: dimension of the features or feature channels\n        num_codes: number of code words\n    \"\"\"\n\n    def __init__(self, channels, num_codes):\n        super(Encoding, self).__init__()\n        # init codewords and smoothing factor\n        self.channels, self.num_codes = channels, num_codes\n        std = 1. / ((num_codes * channels)**0.5)\n        # [num_codes, channels]\n        self.codewords = nn.Parameter(\n            torch.empty(num_codes, channels,\n                        dtype=torch.float).uniform_(-std, std),\n            requires_grad=True)\n        # [num_codes]\n        self.scale = nn.Parameter(\n            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),\n            requires_grad=True)\n\n    @staticmethod\n    def scaled_l2(x, codewords, scale):\n        num_codes, channels = codewords.size()\n        batch_size = x.size(0)\n        reshaped_scale = scale.view((1, 1, num_codes))\n        expanded_x = x.unsqueeze(2).expand(\n            (batch_size, x.size(1), num_codes, channels))\n        reshaped_codewords = codewords.view((1, 1, num_codes, channels))\n\n        scaled_l2_norm = reshaped_scale * (\n            expanded_x - reshaped_codewords).pow(2).sum(dim=3)\n        return scaled_l2_norm\n\n    @staticmethod\n    def aggregate(assigment_weights, x, codewords):\n        num_codes, channels = codewords.size()\n        reshaped_codewords = codewords.view((1, 1, num_codes, channels))\n        batch_size = x.size(0)\n\n        expanded_x = x.unsqueeze(2).expand(\n            (batch_size, x.size(1), num_codes, channels))\n        encoded_feat = (assigment_weights.unsqueeze(3) *\n                        (expanded_x - reshaped_codewords)).sum(dim=1)\n        return encoded_feat\n\n    def forward(self, x):\n        assert x.dim() == 4 and x.size(1) == self.channels\n        # [batch_size, channels, height, width]\n        batch_size = x.size(0)\n        # [batch_size, height x width, channels]\n        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()\n        # assignment_weights: [batch_size, channels, num_codes]\n        assigment_weights = F.softmax(\n            self.scaled_l2(x, self.codewords, self.scale), dim=2)\n        # aggregate\n        encoded_feat = self.aggregate(assigment_weights, x, self.codewords)\n        return encoded_feat\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \\\n                    f'x{self.channels})'\n        return repr_str\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/ops/wrappers.py",
    "content": "import warnings\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef resize(input,\n           size=None,\n           scale_factor=None,\n           mode='nearest',\n           align_corners=None,\n           warning=True):\n    if warning:\n        if size is not None and align_corners:\n            input_h, input_w = tuple(int(x) for x in input.shape[2:])\n            output_h, output_w = tuple(int(x) for x in size)\n            if output_h > input_h or output_w > output_h:\n                if ((output_h > 1 and output_w > 1 and input_h > 1\n                     and input_w > 1) and (output_h - 1) % (input_h - 1)\n                        and (output_w - 1) % (input_w - 1)):\n                    warnings.warn(\n                        f'When align_corners={align_corners}, '\n                        'the output would more aligned if '\n                        f'input size {(input_h, input_w)} is `x+1` and '\n                        f'out size {(output_h, output_w)} is `nx+1`')\n    if isinstance(size, torch.Size):\n        size = tuple(int(x) for x in size)\n    return F.interpolate(input, size, scale_factor, mode, align_corners)\n\n\nclass Upsample(nn.Module):\n\n    def __init__(self,\n                 size=None,\n                 scale_factor=None,\n                 mode='nearest',\n                 align_corners=None):\n        super(Upsample, self).__init__()\n        self.size = size\n        if isinstance(scale_factor, tuple):\n            self.scale_factor = tuple(float(factor) for factor in scale_factor)\n        else:\n            self.scale_factor = float(scale_factor) if scale_factor else None\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def forward(self, x):\n        if not self.size:\n            size = [int(t * self.scale_factor) for t in x.shape[-2:]]\n        else:\n            size = self.size\n        return resize(x, size, None, self.mode, self.align_corners)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/utils/__init__.py",
    "content": "from .collect_env import collect_env\nfrom .logger import get_root_logger\n\n__all__ = ['get_root_logger', 'collect_env']\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/utils/collect_env.py",
    "content": "from mmcv.utils import collect_env as collect_base_env\nfrom mmcv.utils import get_git_hash\n\nimport mmseg\n\n\ndef collect_env():\n    \"\"\"Collect the information of the running environments.\"\"\"\n    env_info = collect_base_env()\n    env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'\n\n    return env_info\n\n\nif __name__ == '__main__':\n    for name, val in collect_env().items():\n        print('{}: {}'.format(name, val))\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/utils/logger.py",
    "content": "import logging\n\nfrom mmcv.utils import get_logger\n\n\ndef get_root_logger(log_file=None, log_level=logging.INFO):\n    \"\"\"Get the root logger.\n\n    The logger will be initialized if it has not been initialized. By default a\n    StreamHandler will be added. If `log_file` is specified, a FileHandler will\n    also be added. The name of the root logger is the top-level package name,\n    e.g., \"mmseg\".\n\n    Args:\n        log_file (str | None): The log filename. If specified, a FileHandler\n            will be added to the root logger.\n        log_level (int): The root logger level. Note that only the process of\n            rank 0 is affected, while other processes will set the level to\n            \"Error\" and be silent most of the time.\n\n    Returns:\n        logging.Logger: The root logger.\n    \"\"\"\n\n    logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)\n\n    return logger\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/mmseg/version.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\n\n__version__ = '0.11.0'\n\n\ndef parse_version_info(version_str):\n    version_info = []\n    for x in version_str.split('.'):\n        if x.isdigit():\n            version_info.append(int(x))\n        elif x.find('rc') != -1:\n            patch_version = x.split('rc')\n            version_info.append(int(patch_version[0]))\n            version_info.append(f'rc{patch_version[1]}')\n    return tuple(version_info)\n\n\nversion_info = parse_version_info(__version__)\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/tools/dist_test.sh",
    "content": "#!/usr/bin/env bash\nCONFIG=$1\nCHECKPOINT=$2\nGPUS=$3\nPORT=7956\n\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n    --nproc_per_node=$GPUS \\\n    --master_port=$PORT \\\n    tools/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/tools/dist_train.sh",
    "content": "#!/usr/bin/env bash\nCONFIG=$1\nGPUS=$2\nPORT=7956\n\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n    --nproc_per_node=$GPUS \\\n    --master_port=$PORT \\\n    tools/train.py $CONFIG --launcher pytorch ${@:3} \\\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/tools/test.py",
    "content": "import argparse\nimport os\n\nimport mmcv\nimport torch\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import get_dist_info, init_dist, load_checkpoint\nfrom mmcv.utils import DictAction\n\nfrom mmseg.apis import multi_gpu_test, single_gpu_test\nfrom mmseg.datasets import build_dataloader, build_dataset\nfrom mmseg.models import build_segmentor\n\nfrom backbone import beit\nfrom backbone import cae\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='mmseg test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument(\n        '--aug-test', action='store_true', help='Use Flip and Multi scale aug')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without perform evaluation. It is'\n        'useful when you want to format the result to a specific format and '\n        'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depends on the dataset, e.g., \"mIoU\"'\n        ' for generic datasets, and \"cityscapes\" for Cityscapes')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where painted images will be saved')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n        'workers, available when gpu_collect is not specified')\n    parser.add_argument(\n        '--options', nargs='+', action=DictAction, help='custom options')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n        or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results / save the results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format_only cannot be both specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file must be a pkl file.')\n\n    cfg = mmcv.Config.fromfile(args.config)\n    if args.options is not None:\n        cfg.merge_from_dict(args.options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    if args.aug_test:\n        # hard code index\n        cfg.data.test.pipeline[1].img_ratios = [\n            0.5, 0.75, 1.0, 1.25, 1.5, 1.75\n        ]\n        
cfg.data.test.pipeline[1].flip = True\n    cfg.model.pretrained = None\n    cfg.data.test.test_mode = True\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    # build the dataloader\n    # TODO: support multiple images per gpu (only minor changes are needed)\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=1,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    model.CLASSES = checkpoint['meta']['CLASSES']\n    model.PALETTE = checkpoint['meta']['PALETTE']\n\n    efficient_test = False\n    if args.eval_options is not None:\n        efficient_test = args.eval_options.get('efficient_test', False)\n\n    if not distributed:\n        model = MMDataParallel(model, device_ids=[0])\n        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                  efficient_test)\n    else:\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False)\n        outputs = multi_gpu_test(model, data_loader, args.tmpdir,\n                                 args.gpu_collect, efficient_test)\n\n    rank, _ = get_dist_info()\n    if rank == 0:\n        if args.out:\n            print(f'\\nwriting results to {args.out}')\n            mmcv.dump(outputs, args.out)\n        kwargs = {} if args.eval_options is None else args.eval_options\n        if args.format_only:\n            dataset.format_results(outputs, **kwargs)\n        if args.eval:\n            dataset.evaluate(outputs, args.eval, **kwargs)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "downstream_tasks/semantic_segmentation/tools/train.py",
    "content": "import argparse\nimport copy\nimport os\nimport os.path as osp\nimport time\n\nimport mmcv\nimport mmcv_custom\nimport torch\nfrom mmcv.runner import init_dist\nfrom mmcv.utils import Config, DictAction, get_git_hash\n\nfrom mmseg import __version__\nfrom mmseg.apis import set_random_seed\nfrom mmcv_custom import train_segmentor\nfrom mmseg.datasets import build_dataset\nfrom mmseg.models import build_segmentor\nfrom mmseg.utils import collect_env, get_root_logger\n\nfrom backbone import beit\nfrom backbone import mae\nfrom backbone import beit_fapn\nfrom backbone import cae\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Train a segmentor')\n    parser.add_argument('config', help='train config file path')\n    parser.add_argument('--work-dir', help='the dir to save logs and models')\n    parser.add_argument(\n        '--load-from', help='the checkpoint file to load weights from')\n    parser.add_argument(\n        '--resume-from', help='the checkpoint file to resume from')\n    parser.add_argument(\n        '--no-validate',\n        action='store_true',\n        help='whether not to evaluate the checkpoint during training')\n    group_gpus = parser.add_mutually_exclusive_group()\n    group_gpus.add_argument(\n        '--gpus',\n        type=int,\n        help='number of gpus to use '\n        '(only applicable to non-distributed training)')\n    group_gpus.add_argument(\n        '--gpu-ids',\n        type=int,\n        nargs='+',\n        help='ids of gpus to use '\n        '(only applicable to non-distributed training)')\n    parser.add_argument('--seed', type=int, default=None, help='random seed')\n    parser.add_argument(\n        '--deterministic',\n        action='store_true',\n        help='whether to set deterministic options for CUDNN backend.')\n    parser.add_argument(\n        '--options', nargs='+', action=DictAction, help='custom options')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    cfg = Config.fromfile(args.config)\n    if args.options is not None:\n        cfg.merge_from_dict(args.options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    # work_dir is determined in this priority: CLI > segment in file > filename\n    if args.work_dir is not None:\n        # update configs according to CLI args if args.work_dir is not None\n        cfg.work_dir = args.work_dir\n    elif cfg.get('work_dir', None) is None:\n        # use config filename as default work_dir if cfg.work_dir is None\n        cfg.work_dir = osp.join('./work_dirs',\n                                osp.splitext(osp.basename(args.config))[0])\n    if args.load_from is not None:\n        cfg.load_from = args.load_from\n    if args.resume_from is not None:\n        cfg.resume_from = args.resume_from\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids\n    else:\n        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        
init_dist(args.launcher, **cfg.dist_params)\n\n    # create work_dir\n    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))\n    # dump config\n    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))\n    # init the logger before other steps\n    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')\n    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)\n\n    # init the meta dict to record some important information such as\n    # environment info and seed, which will be logged\n    meta = dict()\n    # log env info\n    env_info_dict = collect_env()\n    env_info = '\\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])\n    dash_line = '-' * 60 + '\\n'\n    logger.info('Environment info:\\n' + dash_line + env_info + '\\n' +\n                dash_line)\n    meta['env_info'] = env_info\n\n    # log some basic info\n    logger.info(f'Distributed training: {distributed}')\n    logger.info(f'Config:\\n{cfg.pretty_text}')\n\n    # set random seeds\n    if args.seed is not None:\n        logger.info(f'Set random seed to {args.seed}, deterministic: '\n                    f'{args.deterministic}')\n        set_random_seed(args.seed, deterministic=args.deterministic)\n    cfg.seed = args.seed\n    meta['seed'] = args.seed\n    meta['exp_name'] = osp.basename(args.config)\n\n    model = build_segmentor(\n        cfg.model,\n        train_cfg=cfg.get('train_cfg'),\n        test_cfg=cfg.get('test_cfg'))\n\n    logger.info(model)\n\n    datasets = [build_dataset(cfg.data.train)]\n    if len(cfg.workflow) == 2:\n        val_dataset = copy.deepcopy(cfg.data.val)\n        val_dataset.pipeline = cfg.data.train.pipeline\n        datasets.append(build_dataset(val_dataset))\n    if cfg.checkpoint_config is not None:\n        # save mmseg version, config file content and class names in\n        # checkpoints as meta data\n        cfg.checkpoint_config.meta = dict(\n            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',\n            config=cfg.pretty_text,\n            CLASSES=datasets[0].CLASSES,\n            PALETTE=datasets[0].PALETTE)\n    # add an attribute for visualization convenience\n    model.CLASSES = datasets[0].CLASSES\n    train_segmentor(\n        model,\n        datasets,\n        cfg,\n        distributed=distributed,\n        validate=(not args.no_validate),\n        timestamp=timestamp,\n        meta=meta)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "furnace/dataset_folder.py",
    "content": "from torchvision.datasets.vision import VisionDataset\n\nfrom PIL import Image\n\nimport os\nimport os.path\nimport random\nfrom typing import Any, Callable, cast, Dict, List, Optional, Tuple\n\n\ndef has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:\n    \"\"\"Checks if a file is an allowed extension.\n\n    Args:\n        filename (string): path to a file\n        extensions (tuple of strings): extensions to consider (lowercase)\n\n    Returns:\n        bool: True if the filename ends with one of given extensions\n    \"\"\"\n    return filename.lower().endswith(extensions)\n\n\ndef is_image_file(filename: str) -> bool:\n    \"\"\"Checks if a file is an allowed image extension.\n\n    Args:\n        filename (string): path to a file\n\n    Returns:\n        bool: True if the filename ends with a known image extension\n    \"\"\"\n    return has_file_allowed_extension(filename, IMG_EXTENSIONS)\n\n\ndef make_dataset(\n    directory: str,\n    class_to_idx: Dict[str, int],\n    extensions: Optional[Tuple[str, ...]] = None,\n    is_valid_file: Optional[Callable[[str], bool]] = None,\n) -> List[Tuple[str, int]]:\n    instances = []\n    directory = os.path.expanduser(directory)\n    both_none = extensions is None and is_valid_file is None\n    both_something = extensions is not None and is_valid_file is not None\n    if both_none or both_something:\n        raise ValueError(\"Both extensions and is_valid_file cannot be None or not None at the same time\")\n    if extensions is not None:\n        def is_valid_file(x: str) -> bool:\n            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))\n    is_valid_file = cast(Callable[[str], bool], is_valid_file)\n    for target_class in sorted(class_to_idx.keys()):\n        class_index = class_to_idx[target_class]\n        target_dir = os.path.join(directory, target_class)\n        if not os.path.isdir(target_dir):\n            continue\n        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):\n            for fname in sorted(fnames):\n                path = os.path.join(root, fname)\n                if is_valid_file(path):\n                    item = path, class_index\n                    instances.append(item)\n    return instances\n\n\nclass DatasetFolder(VisionDataset):\n    \"\"\"A generic data loader where the samples are arranged in this way: ::\n\n        root/class_x/xxx.ext\n        root/class_x/xxy.ext\n        root/class_x/xxz.ext\n\n        root/class_y/123.ext\n        root/class_y/nsdf3.ext\n        root/class_y/asd932_.ext\n\n    Args:\n        root (string): Root directory path.\n        loader (callable): A function to load a sample given its path.\n        extensions (tuple[string]): A list of allowed extensions.\n            both extensions and is_valid_file should not be passed.\n        transform (callable, optional): A function/transform that takes in\n            a sample and returns a transformed version.\n            E.g, ``transforms.RandomCrop`` for images.\n        target_transform (callable, optional): A function/transform that takes\n            in the target and transforms it.\n        is_valid_file (callable, optional): A function that takes path of a file\n            and check if the file is a valid file (used to check of corrupt files)\n            both extensions and is_valid_file should not be passed.\n\n     Attributes:\n        classes (list): List of the class names sorted alphabetically.\n        class_to_idx (dict): Dict 
with items (class_name, class_index).\n        samples (list): List of (sample path, class_index) tuples\n        targets (list): The class_index value for each image in the dataset\n    \"\"\"\n\n    def __init__(\n            self,\n            root: str,\n            loader: Callable[[str], Any],\n            extensions: Optional[Tuple[str, ...]] = None,\n            transform: Optional[Callable] = None,\n            target_transform: Optional[Callable] = None,\n            is_valid_file: Optional[Callable[[str], bool]] = None,\n    ) -> None:\n        super(DatasetFolder, self).__init__(root, transform=transform,\n                                            target_transform=target_transform)\n        classes, class_to_idx = self._find_classes(self.root)\n        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)\n        if len(samples) == 0:\n            msg = \"Found 0 files in subfolders of: {}\\n\".format(self.root)\n            if extensions is not None:\n                msg += \"Supported extensions are: {}\".format(\",\".join(extensions))\n            raise RuntimeError(msg)\n\n        self.loader = loader\n        self.extensions = extensions\n\n        self.classes = classes\n        self.class_to_idx = class_to_idx\n        self.samples = samples\n        self.targets = [s[1] for s in samples]\n\n    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:\n        \"\"\"\n        Finds the class folders in a dataset.\n\n        Args:\n            dir (string): Root directory path.\n\n        Returns:\n            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.\n\n        Ensures:\n            No class is a subdirectory of another.\n        \"\"\"\n        classes = [d.name for d in os.scandir(dir) if d.is_dir()]\n        classes.sort()\n        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}\n        return classes, class_to_idx\n\n    def __getitem__(self, index: int) -> Tuple[Any, Any]:\n        \"\"\"\n        Args:\n            index (int): Index\n\n        Returns:\n            tuple: (sample, target) where target is class_index of the target class.\n        \"\"\"\n        while True:\n            try:\n                path, target = self.samples[index]\n                sample = self.loader(path)\n                break\n            except Exception as e:\n                print(e)\n                index = random.randint(0, len(self.samples) - 1)\n\n        if self.transform is not None:\n            sample = self.transform(sample)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n\n        return sample, target\n\n    def __len__(self) -> int:\n        return len(self.samples)\n\n\nIMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')\n\n\ndef pil_loader(path: str) -> Image.Image:\n    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\n    with open(path, 'rb') as f:\n        img = Image.open(f)\n        return img.convert('RGB')\n\n\n# TODO: specify the return type\ndef accimage_loader(path: str) -> Any:\n    import accimage\n    try:\n        return accimage.Image(path)\n    except IOError:\n        # Potentially a decoding problem, fall back to PIL.Image\n        return pil_loader(path)\n\n\ndef default_loader(path: str) -> Any:\n    from torchvision import get_image_backend\n    if get_image_backend() == 'accimage':\n      
  return accimage_loader(path)\n    else:\n        return pil_loader(path)\n\n\nclass ImageFolder(DatasetFolder):\n    \"\"\"A generic data loader where the images are arranged in this way: ::\n\n        root/dog/xxx.png\n        root/dog/xxy.png\n        root/dog/xxz.png\n\n        root/cat/123.png\n        root/cat/nsdf3.png\n        root/cat/asd932_.png\n\n    Args:\n        root (string): Root directory path.\n        transform (callable, optional): A function/transform that  takes in an PIL image\n            and returns a transformed version. E.g, ``transforms.RandomCrop``\n        target_transform (callable, optional): A function/transform that takes in the\n            target and transforms it.\n        loader (callable, optional): A function to load an image given its path.\n        is_valid_file (callable, optional): A function that takes path of an Image file\n            and check if the file is a valid file (used to check of corrupt files)\n\n     Attributes:\n        classes (list): List of the class names sorted alphabetically.\n        class_to_idx (dict): Dict with items (class_name, class_index).\n        imgs (list): List of (image path, class_index) tuples\n    \"\"\"\n\n    def __init__(\n            self,\n            root: str,\n            transform: Optional[Callable] = None,\n            target_transform: Optional[Callable] = None,\n            loader: Callable[[str], Any] = default_loader,\n            is_valid_file: Optional[Callable[[str], bool]] = None,\n    ):\n        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,\n                                          transform=transform,\n                                          target_transform=target_transform,\n                                          is_valid_file=is_valid_file)\n        self.imgs = self.samples\n"
  },
  {
    "path": "furnace/datasets.py",
    "content": "import os\nimport torch\n\nfrom torchvision import datasets, transforms\n\nfrom timm.data.constants import \\\n    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD\nfrom furnace.transforms import RandomResizedCropAndInterpolationWithTwoPic\nfrom timm.data import create_transform\n\nfrom dall_e.utils import map_pixels\nfrom furnace.masking_generator import MaskingGenerator, RandomMaskingGenerator\nfrom furnace.dataset_folder import ImageFolder\n\ndef preprocess_vqgan(x):\n    x = 2.*x - 1.\n    return x\n\nclass DataAugmentationForCAE(object):\n    def __init__(self, args):\n        imagenet_default_mean_and_std = args.imagenet_default_mean_and_std\n        mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN\n        std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD\n\n        if args.color_jitter > 0:\n            self.common_transform = transforms.Compose([\n                transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter),\n                transforms.RandomHorizontalFlip(p=0.5),\n                RandomResizedCropAndInterpolationWithTwoPic(\n                    size=args.input_size, second_size=args.second_input_size,\n                    interpolation=args.train_interpolation, second_interpolation=args.second_interpolation,\n                    scale=(args.crop_min_size, args.crop_max_size),\n                ),\n            ])\n        else:\n            self.common_transform = transforms.Compose([\n                transforms.RandomHorizontalFlip(p=0.5),\n                RandomResizedCropAndInterpolationWithTwoPic(\n                    size=args.input_size, second_size=args.second_input_size,\n                    interpolation=args.train_interpolation, second_interpolation=args.second_interpolation,\n                    scale=(args.crop_min_size, args.crop_max_size),\n                ),\n            ])\n\n        self.patch_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(\n                mean=torch.tensor(mean),\n                std=torch.tensor(std))\n        ])\n\n        if args.discrete_vae_type == \"dall-e\":\n            self.visual_token_transform = transforms.Compose([\n                transforms.ToTensor(),\n                map_pixels,\n            ])\n        elif args.discrete_vae_type == \"vqgan_gumbel_f8_8192\":\n            self.visual_token_transform = transforms.Compose([\n                transforms.ToTensor(),\n                preprocess_vqgan,\n            ])\n        elif args.discrete_vae_type == \"customized\":\n            self.visual_token_transform = transforms.Compose([\n                transforms.ToTensor(),\n                transforms.Normalize(\n                    mean=IMAGENET_INCEPTION_MEAN,\n                    std=IMAGENET_INCEPTION_STD,\n                ),\n            ])\n        else:\n            raise NotImplementedError()\n        \n        if args.mask_generator == 'block':\n            self.masked_position_generator = MaskingGenerator(\n                args.window_size, num_masking_patches=args.num_mask_patches,\n                max_num_patches=args.max_mask_patches_per_block,\n                min_num_patches=args.min_mask_patches_per_block,\n            )\n        elif args.mask_generator == 'random':\n            self.masked_position_generator = RandomMaskingGenerator(\n                args.window_size, 
ratio_masking_patches=args.ratio_mask_patches\n            )\n        \n\n    def __call__(self, image):\n        for_patches, for_visual_tokens = self.common_transform(image)\n\n        return \\\n            self.patch_transform(for_patches), self.visual_token_transform(for_visual_tokens), \\\n            self.masked_position_generator()\n\n    def __repr__(self):\n        repr = \"(DataAugmentationForCAE,\\n\"\n        repr += \"  common_transform = %s,\\n\" % str(self.common_transform)\n        repr += \"  patch_transform = %s,\\n\" % str(self.patch_transform)\n        repr += \"  visual_tokens_transform = %s,\\n\" % str(self.visual_token_transform)\n        repr += \"  Masked position generator = %s,\\n\" % str(self.masked_position_generator)\n        repr += \")\"\n        return repr\n\ndef build_cae_pretraining_dataset(args):\n    transform = DataAugmentationForCAE(args)\n    print(\"Data Aug = %s\" % str(transform))\n    return ImageFolder(args.data_path, transform=transform)\n\n\ndef build_dataset(is_train, args):\n    transform = build_transform(is_train, args)\n\n    print(\"Transform = \")\n    if isinstance(transform, tuple):\n        for trans in transform:\n            print(\" - - - - - - - - - - \")\n            for t in trans.transforms:\n                print(t)\n    else:\n        for t in transform.transforms:\n            print(t)\n    print(\"---------------------------\")\n\n    if args.data_set == 'CIFAR':\n        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)\n        nb_classes = 100\n    elif args.data_set == 'IMNET':\n        root = os.path.join(args.data_path, 'train' if is_train else 'val')\n        dataset = datasets.ImageFolder(root, transform=transform)\n        nb_classes = 1000\n    elif args.data_set == \"image_folder\":\n        root = args.data_path if is_train else args.eval_data_path\n        dataset = ImageFolder(root, transform=transform)\n        nb_classes = args.nb_classes\n        assert len(dataset.class_to_idx) == nb_classes\n    else:\n        raise NotImplementedError()\n    assert nb_classes == args.nb_classes\n    print(\"Number of the class = %d\" % args.nb_classes)\n\n    return dataset, nb_classes\n\n\ndef build_transform(is_train, args):\n    resize_im = args.input_size > 32\n    imagenet_default_mean_and_std = args.imagenet_default_mean_and_std\n    mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN\n    std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD\n\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=args.input_size,\n            is_training=True,\n            color_jitter=args.color_jitter,\n            auto_augment=args.aa,\n            interpolation=args.train_interpolation,\n            re_prob=args.reprob,\n            re_mode=args.remode,\n            re_count=args.recount,\n            mean=mean,\n            std=std,\n        )\n        if not resize_im:\n            # replace RandomResizedCropAndInterpolation with\n            # RandomCrop\n            transform.transforms[0] = transforms.RandomCrop(\n                args.input_size, padding=4)\n        return transform\n\n    t = []\n    if resize_im:\n        if args.crop_pct is None:\n            if args.input_size < 384:\n                args.crop_pct = 224 / 256\n            else:\n                args.crop_pct = 1.0\n        size = 
int(args.input_size / args.crop_pct)\n        t.append(\n            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images\n        )\n        t.append(transforms.CenterCrop(args.input_size))\n\n    t.append(transforms.ToTensor())\n    t.append(transforms.Normalize(mean, std))\n    return transforms.Compose(t)\n"
  },
  {
    "path": "furnace/engine_for_finetuning.py",
    "content": "import math\nimport sys\nimport time\nfrom typing import Iterable, Optional\n\nimport torch\n\nfrom timm.data import Mixup\nfrom timm.utils import accuracy, ModelEma\n\nimport furnace.utils as utils\n\n\ndef train_class_batch(model, samples, target, criterion):\n    outputs = model(samples)\n    loss = criterion(outputs, target)\n    return loss, outputs\n\n\ndef get_loss_scale_for_deepspeed(model):\n    optimizer = model.optimizer\n    return optimizer.loss_scale if hasattr(optimizer, \"loss_scale\") else optimizer.cur_scale\n\n\ndef train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,\n                    data_loader: Iterable, optimizer: torch.optim.Optimizer,\n                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,\n                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,\n                    start_steps=None, lr_schedule_values=None, wd_schedule_values=None,\n                    num_training_steps_per_epoch=None, update_freq=None):\n    model.train(True)\n    metric_logger = utils.MetricLogger(delimiter=\"  \")\n    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    header = 'Epoch: [{}]'.format(epoch)\n    print_freq = 10\n\n    if loss_scaler is None:\n        model.zero_grad()\n        model.micro_steps = 0\n    else:\n        optimizer.zero_grad()\n\n    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):\n        step = data_iter_step // update_freq\n        if step >= num_training_steps_per_epoch:\n            continue\n        it = start_steps + step  # global training iteration\n        # Update LR & WD for the first acc\n        if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:\n            for i, param_group in enumerate(optimizer.param_groups):\n                if lr_schedule_values is not None:\n                    if \"lr_scale\" in param_group:\n                        param_group[\"lr\"] = lr_schedule_values[it] * param_group[\"lr_scale\"]\n                    else:\n                        param_group[\"lr\"] = lr_schedule_values[it]\n                if wd_schedule_values is not None and param_group[\"weight_decay\"] > 0:\n                    param_group[\"weight_decay\"] = wd_schedule_values[it]\n\n        samples = samples.to(device, non_blocking=True)\n        targets = targets.to(device, non_blocking=True)\n\n        if mixup_fn is not None:\n            samples, targets = mixup_fn(samples, targets)\n\n        if loss_scaler is None:\n            samples = samples.half()\n            loss, output = train_class_batch(\n                model, samples, targets, criterion)\n        else:\n            with torch.cuda.amp.autocast():\n                loss, output = train_class_batch(\n                    model, samples, targets, criterion)\n\n        loss_value = loss.item()\n\n        if not math.isfinite(loss_value):\n            print(\"Loss is {}, stopping training\".format(loss_value))\n            sys.exit(1)\n\n        if loss_scaler is None:\n            loss /= update_freq\n            model.backward(loss)\n            model.step()\n\n            if (data_iter_step + 1) % update_freq == 0:\n                # model.zero_grad()\n                # Deepspeed will call step() & model.zero_grad() automatic\n    
            if model_ema is not None:\n                    model_ema.update(model)\n            grad_norm = None\n            loss_scale_value = get_loss_scale_for_deepspeed(model)\n        else:\n            # this attribute is added by timm on one optimizer (adahessian)\n            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n            loss /= update_freq\n            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,\n                                    parameters=model.parameters(), create_graph=is_second_order,\n                                    update_grad=(data_iter_step + 1) % update_freq == 0)\n            if (data_iter_step + 1) % update_freq == 0:\n                optimizer.zero_grad()\n                if model_ema is not None:\n                    model_ema.update(model)\n            loss_scale_value = loss_scaler.state_dict()[\"scale\"]\n\n        torch.cuda.synchronize()\n\n        if mixup_fn is None:\n            class_acc = (output.max(-1)[-1] == targets).float().mean()\n        else:\n            class_acc = None\n        metric_logger.update(loss=loss_value)\n        metric_logger.update(class_acc=class_acc)\n        metric_logger.update(loss_scale=loss_scale_value)\n        min_lr = 10.\n        max_lr = 0.\n        for group in optimizer.param_groups:\n            min_lr = min(min_lr, group[\"lr\"])\n            max_lr = max(max_lr, group[\"lr\"])\n\n        metric_logger.update(lr=max_lr)\n        metric_logger.update(min_lr=min_lr)\n        weight_decay_value = None\n        for group in optimizer.param_groups:\n            if group[\"weight_decay\"] > 0:\n                weight_decay_value = group[\"weight_decay\"]\n        metric_logger.update(weight_decay=weight_decay_value)\n        metric_logger.update(grad_norm=grad_norm)\n\n        if log_writer is not None:\n            log_writer.update(loss=loss_value, head=\"loss\")\n            log_writer.update(class_acc=class_acc, head=\"loss\")\n            log_writer.update(loss_scale=loss_scale_value, head=\"opt\")\n            log_writer.update(lr=max_lr, head=\"opt\")\n            log_writer.update(min_lr=min_lr, head=\"opt\")\n            log_writer.update(weight_decay=weight_decay_value, head=\"opt\")\n            log_writer.update(grad_norm=grad_norm, head=\"opt\")\n\n            log_writer.set_step()\n\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    now_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())\n    print(now_time, \"Averaged stats:\", metric_logger)\n    # print(\"Averaged stats:\", metric_logger)\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}\n\n\n@torch.no_grad()\ndef evaluate(data_loader, model, device):\n    criterion = torch.nn.CrossEntropyLoss()\n\n    metric_logger = utils.MetricLogger(delimiter=\"  \")\n    header = 'Test:'\n\n    # switch to evaluation mode\n    model.eval()\n\n    for batch in metric_logger.log_every(data_loader, 10, header):\n        images = batch[0]\n        target = batch[-1]\n        images = images.to(device, non_blocking=True)\n        target = target.to(device, non_blocking=True)\n\n        # compute output\n        with torch.cuda.amp.autocast():\n            output = model(images)\n            loss = criterion(output, target)\n\n        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n        batch_size = images.shape[0]\n        metric_logger.update(loss=loss.item())\n        
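# track top-1 / top-5 accuracy weighted by the per-GPU batch size
        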
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)\n        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'\n          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))\n\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}\n"
  },
  {
    "path": "furnace/engine_for_pretraining.py",
    "content": "import math\nimport sys\nimport time\nfrom typing import Iterable\n\nimport torch\nimport torch.nn as nn\n\nimport furnace.utils as utils\nimport torch.nn.functional as F\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\n\ndef loss_selector(loss_type, pred, target):\n    if loss_type == 'mse':\n        return F.mse_loss(pred, target, reduction=\"mean\")\n    elif loss_type == 'kld':\n        return F.kl_div(F.log_softmax(pred, dim=-1), F.softmax(target, dim=-1), reduction='mean')\n\ndef train_one_epoch(model: torch.nn.Module, d_vae: torch.nn.Module,\n                    data_loader: Iterable, optimizer: torch.optim.Optimizer,\n                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,\n                    log_writer=None, lr_scheduler=None, start_steps=None,\n                    lr_schedule_values=None, wd_schedule_values=None, args=None):\n    model.train()\n    metric_logger = utils.MetricLogger(delimiter=\"  \")\n    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    header = 'Epoch: [{}]'.format(epoch)\n    print_freq = 10\n\n    for step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):\n        # assign learning rate & weight decay for each step\n        it = start_steps + step  # global training iteration\n        if lr_schedule_values is not None or wd_schedule_values is not None:\n            for i, param_group in enumerate(optimizer.param_groups):\n                if lr_schedule_values is not None:\n                    param_group[\"lr\"] = lr_schedule_values[it] * param_group[\"lr_scale\"]\n                if wd_schedule_values is not None and param_group[\"weight_decay\"] > 0:\n                    param_group[\"weight_decay\"] = wd_schedule_values[it]\n\n        samples, images, bool_masked_pos = batch\n        images = images.to(device, non_blocking=True)\n        samples = samples.to(device, non_blocking=True)\n        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True)\n\n        with torch.no_grad():\n            input_ids = d_vae.get_codebook_indices(images).flatten(1)\n            bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)\n            labels = input_ids[bool_masked_pos]\n\n        with torch.cuda.amp.autocast():\n            outputs, latent, latent_target = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False)\n\n            loss_main = nn.CrossEntropyLoss()(input=outputs.float(), target=labels)\n            loss_align = args.align_loss_weight * loss_selector('mse', latent.float(), latent_target.detach().float())\n            loss = loss_main + loss_align\n\n        loss_value = loss.item()\n        loss_main_value = loss_main.item()\n        loss_align_value = loss_align.item()\n\n        if not math.isfinite(loss_value):\n            print(\"Loss is {}, stopping training\".format(loss_value))\n            sys.exit(1)\n\n        optimizer.zero_grad()\n        # this attribute is added by timm on one optimizer (adahessian)\n        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n        grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,\n                                parameters=model.parameters(), create_graph=is_second_order)\n        loss_scale_value = loss_scaler.state_dict()[\"scale\"]\n\n        torch.cuda.synchronize()\n\n        
mlm_acc = (outputs.max(-1)[1] == labels).float().mean().item()\n        metric_logger.update(mlm_acc=mlm_acc)\n        if log_writer is not None:\n            log_writer.update(mlm_acc=mlm_acc, head=\"loss\")\n\n\n        metric_logger.update(loss=loss_value)\n        metric_logger.update(loss_main=loss_main_value)\n        metric_logger.update(loss_align=loss_align_value)\n        metric_logger.update(loss_scale=loss_scale_value)\n        min_lr = 10.\n        max_lr = 0.\n        for group in optimizer.param_groups:\n            min_lr = min(min_lr, group[\"lr\"])\n            max_lr = max(max_lr, group[\"lr\"])\n\n        metric_logger.update(lr=max_lr)\n        metric_logger.update(min_lr=min_lr)\n        weight_decay_value = None\n        for group in optimizer.param_groups:\n            if group[\"weight_decay\"] > 0:\n                weight_decay_value = group[\"weight_decay\"]\n        metric_logger.update(weight_decay=weight_decay_value)\n        metric_logger.update(grad_norm=grad_norm)\n\n        if log_writer is not None:\n            log_writer.update(loss=loss_value, head=\"loss\")\n            log_writer.update(loss=loss_main_value, head=\"loss_main\")\n            log_writer.update(loss=loss_align_value, head=\"loss_align\")\n            log_writer.update(loss_scale=loss_scale_value, head=\"opt\")\n            log_writer.update(lr=max_lr, head=\"opt\")\n            log_writer.update(min_lr=min_lr, head=\"opt\")\n            log_writer.update(weight_decay=weight_decay_value, head=\"opt\")\n            log_writer.update(grad_norm=grad_norm, head=\"opt\")\n\n            log_writer.set_step()\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(start_steps + step)\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    now_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())\n    print(now_time, \"Averaged stats:\", metric_logger)\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}\n"
  },
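For reference, the pretraining step above optimizes two terms: a cross-entropy loss over the d-VAE codebook ids of the masked patches, and an MSE alignment loss between the regressor's predicted latents and the detached target latents, weighted by `align_loss_weight`. A minimal sketch of that combination on random tensors (all shapes and the weight are illustrative only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative shapes: 2 images x 98 masked patches, 8192-entry codebook, 768-d latents.
num_masked, vocab_size, dim = 2 * 98, 8192, 768
align_loss_weight = 2.0  # stand-in for args.align_loss_weight

outputs = torch.randn(num_masked, vocab_size)          # decoder logits at masked positions
labels = torch.randint(0, vocab_size, (num_masked,))   # d-VAE codebook indices of masked patches
latent = torch.randn(num_masked, dim)                  # latent predictions from the regressor
latent_target = torch.randn(num_masked, dim)           # target latents (detached, no gradient)

loss_main = nn.CrossEntropyLoss()(outputs, labels)
loss_align = align_loss_weight * F.mse_loss(latent, latent_target.detach(), reduction="mean")
loss = loss_main + loss_align
print(loss_main.item(), loss_align.item(), loss.item())
```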
  {
    "path": "furnace/masking_generator.py",
    "content": "import random\nimport math\nimport numpy as np\n\nclass MaskingGenerator:\n    def __init__(\n            self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,\n            min_aspect=0.3, max_aspect=None):\n        if not isinstance(input_size, tuple):\n            input_size = (input_size, ) * 2\n        self.height, self.width = input_size\n\n        self.num_patches = self.height * self.width\n        self.num_masking_patches = num_masking_patches\n\n        self.min_num_patches = min_num_patches\n        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches\n\n        max_aspect = max_aspect or 1 / min_aspect\n        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))\n\n    def __repr__(self):\n        repr_str = \"Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)\" % (\n            self.height, self.width, self.min_num_patches, self.max_num_patches,\n            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])\n        return repr_str\n\n    def get_shape(self):\n        return self.height, self.width\n\n    def _mask(self, mask, max_mask_patches):\n        delta = 0\n        for attempt in range(10):\n            target_area = random.uniform(self.min_num_patches, max_mask_patches)\n            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))\n            h = int(round(math.sqrt(target_area * aspect_ratio)))\n            w = int(round(math.sqrt(target_area / aspect_ratio)))\n            if w < self.width and h < self.height:\n                top = random.randint(0, self.height - h)\n                left = random.randint(0, self.width - w)\n\n                num_masked = mask[top: top + h, left: left + w].sum()\n                # Overlap\n                if 0 < h * w - num_masked <= max_mask_patches:\n                    for i in range(top, top + h):\n                        for j in range(left, left + w):\n                            if mask[i, j] == 0:\n                                mask[i, j] = 1\n                                delta += 1\n\n                if delta > 0:\n                    break\n        return delta\n\n    def __call__(self):\n        mask = np.zeros(shape=self.get_shape(), dtype=int)\n        mask_count = 0\n        while mask_count != self.num_masking_patches:\n            max_mask_patches = self.num_masking_patches - mask_count\n            max_mask_patches = min(max_mask_patches, self.max_num_patches)\n\n            delta = self._mask(mask, max_mask_patches)\n            mask_count += delta\n        \n        return mask\n\n\nclass RandomMaskingGenerator:\n    def __init__(\n            self, input_size, ratio_masking_patches):\n        if not isinstance(input_size, tuple):\n            input_size = (input_size, ) * 2\n        self.height, self.width = input_size\n\n        self.num_patches = self.height * self.width\n        self.num_masking_patches = int(ratio_masking_patches * self.num_patches)\n\n    def __repr__(self):\n        repr_str = \"Maks: total patches {}, mask patches {}\".format(\n            self.num_patches, self.num_masking_patches\n        )\n        return repr_str\n\n\n    def __call__(self):\n        mask = np.hstack([\n            np.zeros(self.num_patches - self.num_masking_patches),\n            np.ones(self.num_masking_patches),\n        ])\n        np.random.shuffle(mask)\n        \n        return mask\n"
  },
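Both generators return a mask over the patch grid: the block-wise generator keeps sampling rectangles until exactly `num_masking_patches` entries are set, while the random generator shuffles a flat 0/1 vector. A short usage sketch on a 14x14 grid (a 224px input with 16px patches); the counts are illustrative:

```python
from furnace.masking_generator import MaskingGenerator, RandomMaskingGenerator

window_size = (14, 14)  # 224 / 16 patches per side

block_gen = MaskingGenerator(window_size, num_masking_patches=98)
block_mask = block_gen()                  # (14, 14) int array with exactly 98 ones
print(block_gen, int(block_mask.sum()))

rand_gen = RandomMaskingGenerator(window_size, ratio_masking_patches=0.5)
rand_mask = rand_gen()                    # flat array of length 196 with 98 ones
print(rand_gen, int(rand_mask.sum()))
```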
  {
    "path": "furnace/optim_factory.py",
    "content": "import torch\nfrom torch import optim as optim\n\nfrom timm.optim.adafactor import Adafactor\nfrom timm.optim.adahessian import Adahessian\nfrom timm.optim.adamp import AdamP\nfrom timm.optim.lookahead import Lookahead\nfrom timm.optim.nadam import Nadam\nfrom timm.optim.novograd import NovoGrad\nfrom timm.optim.nvnovograd import NvNovoGrad\nfrom timm.optim.radam import RAdam\nfrom timm.optim.rmsprop_tf import RMSpropTF\nfrom timm.optim.sgdp import SGDP\n\nimport json\n\ntry:\n    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\n\ndef get_num_layer_for_vit(var_name, num_max_layer):\n    if var_name in (\"cls_token\", \"mask_token\", \"pos_embed\"):\n        return 0\n    elif var_name.startswith(\"patch_embed\"):\n        return 0\n    elif var_name.startswith(\"rel_pos_bias\"):\n        return num_max_layer - 1\n    elif var_name.startswith(\"blocks\"):\n        layer_id = int(var_name.split('.')[1])\n        return layer_id + 1\n    else:\n        return num_max_layer - 1\n\n\nclass LayerDecayValueAssigner(object):\n    def __init__(self, values):\n        self.values = values\n\n    def get_scale(self, layer_id):\n        return self.values[layer_id]\n\n    def get_layer_id(self, var_name):\n        return get_num_layer_for_vit(var_name, len(self.values))\n\n\ndef get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):\n    parameter_group_names = {}\n    parameter_group_vars = {}\n\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue  # frozen weights\n        if len(param.shape) == 1 or name.endswith(\".bias\") or name in skip_list:\n            group_name = \"no_decay\"\n            this_weight_decay = 0.\n        else:\n            group_name = \"decay\"\n            this_weight_decay = weight_decay\n        if get_num_layer is not None:\n            layer_id = get_num_layer(name)\n            group_name = \"layer_%d_%s\" % (layer_id, group_name)\n        else:\n            layer_id = None\n\n        if group_name not in parameter_group_names:\n            if get_layer_scale is not None:\n                scale = get_layer_scale(layer_id)\n            else:\n                scale = 1.\n\n            parameter_group_names[group_name] = {\n                \"weight_decay\": this_weight_decay,\n                \"params\": [],\n                \"lr_scale\": scale\n            }\n            parameter_group_vars[group_name] = {\n                \"weight_decay\": this_weight_decay,\n                \"params\": [],\n                \"lr_scale\": scale\n            }\n\n        parameter_group_vars[group_name][\"params\"].append(param)\n        parameter_group_names[group_name][\"params\"].append(name)\n    print(\"Param groups = %s\" % json.dumps(parameter_group_names, indent=2))\n    return list(parameter_group_vars.values())\n\n\ndef create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):\n    opt_lower = args.opt.lower()\n    weight_decay = args.weight_decay\n    if weight_decay and filter_bias_and_bn:\n        skip = {}\n        if skip_list is not None:\n            skip = skip_list\n        elif hasattr(model, 'no_weight_decay'):\n            skip = model.no_weight_decay()\n        parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)\n        weight_decay = 0.\n    else:\n        
parameters = model.parameters()\n\n    if 'fused' in opt_lower:\n        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'\n\n    opt_args = dict(lr=args.lr, weight_decay=weight_decay)\n    if hasattr(args, 'opt_eps') and args.opt_eps is not None:\n        opt_args['eps'] = args.opt_eps\n    if hasattr(args, 'opt_betas') and args.opt_betas is not None:\n        opt_args['betas'] = args.opt_betas\n\n    opt_split = opt_lower.split('_')\n    opt_lower = opt_split[-1]\n    if opt_lower == 'sgd' or opt_lower == 'nesterov':\n        opt_args.pop('eps', None)\n        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)\n    elif opt_lower == 'momentum':\n        opt_args.pop('eps', None)\n        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)\n    elif opt_lower == 'adam':\n        optimizer = optim.Adam(parameters, **opt_args)\n    elif opt_lower == 'adamw':\n        optimizer = optim.AdamW(parameters, **opt_args)\n    elif opt_lower == 'nadam':\n        optimizer = Nadam(parameters, **opt_args)\n    elif opt_lower == 'radam':\n        optimizer = RAdam(parameters, **opt_args)\n    elif opt_lower == 'adamp':\n        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)\n    elif opt_lower == 'sgdp':\n        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)\n    elif opt_lower == 'adadelta':\n        optimizer = optim.Adadelta(parameters, **opt_args)\n    elif opt_lower == 'adafactor':\n        if not args.lr:\n            opt_args['lr'] = None\n        optimizer = Adafactor(parameters, **opt_args)\n    elif opt_lower == 'adahessian':\n        optimizer = Adahessian(parameters, **opt_args)\n    elif opt_lower == 'rmsprop':\n        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)\n    elif opt_lower == 'rmsproptf':\n        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)\n    elif opt_lower == 'novograd':\n        optimizer = NovoGrad(parameters, **opt_args)\n    elif opt_lower == 'nvnovograd':\n        optimizer = NvNovoGrad(parameters, **opt_args)\n    elif opt_lower == 'fusedsgd':\n        opt_args.pop('eps', None)\n        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)\n    elif opt_lower == 'fusedmomentum':\n        opt_args.pop('eps', None)\n        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)\n    elif opt_lower == 'fusedadam':\n        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)\n    elif opt_lower == 'fusedadamw':\n        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)\n    elif opt_lower == 'fusedlamb':\n        optimizer = FusedLAMB(parameters, **opt_args)\n    elif opt_lower == 'fusednovograd':\n        opt_args.setdefault('betas', (0.95, 0.98))\n        optimizer = FusedNovoGrad(parameters, **opt_args)\n    else:\n        assert False and \"Invalid optimizer\"\n        raise ValueError\n\n    if len(opt_split) > 1:\n        if opt_split[0] == 'lookahead':\n            optimizer = Lookahead(optimizer)\n\n    return optimizer\n"
  },
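`create_optimizer` builds per-group weight decay and, when a `LayerDecayValueAssigner` is supplied, a per-layer `lr_scale` that the training loop multiplies into the scheduled learning rate. A sketch of wiring it up, using a stock timm ViT as a stand-in model and a hypothetical layer decay of 0.75; the argument namespace covers only the fields `create_optimizer` actually reads:

```python
import argparse
import timm
from furnace.optim_factory import LayerDecayValueAssigner, create_optimizer

model = timm.create_model('vit_base_patch16_224', pretrained=False)  # stand-in backbone

num_layers = 12      # transformer blocks in the stand-in model
layer_decay = 0.75   # hypothetical; 1.0 would disable layer-wise decay
# One scale per layer id: id 0 (patch_embed / cls_token) gets the smallest, the last id gets 1.0.
values = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
assigner = LayerDecayValueAssigner(values)

args = argparse.Namespace(opt='adamw', lr=1.5e-3, weight_decay=0.05, momentum=0.9,
                          opt_eps=1e-8, opt_betas=(0.9, 0.999))
optimizer = create_optimizer(args, model,
                             get_num_layer=assigner.get_layer_id,
                             get_layer_scale=assigner.get_scale)
print(len(optimizer.param_groups), 'parameter groups')
```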
  {
    "path": "furnace/transforms.py",
    "content": "import torch\nimport torchvision.transforms.functional as F\nfrom PIL import Image\nimport warnings\nimport math\nimport random\nimport numpy as np\n\n\nclass ToNumpy:\n\n    def __call__(self, pil_img):\n        np_img = np.array(pil_img, dtype=np.uint8)\n        if np_img.ndim < 3:\n            np_img = np.expand_dims(np_img, axis=-1)\n        np_img = np.rollaxis(np_img, 2)  # HWC to CHW\n        return np_img\n\n\nclass ToTensor:\n\n    def __init__(self, dtype=torch.float32):\n        self.dtype = dtype\n\n    def __call__(self, pil_img):\n        np_img = np.array(pil_img, dtype=np.uint8)\n        if np_img.ndim < 3:\n            np_img = np.expand_dims(np_img, axis=-1)\n        np_img = np.rollaxis(np_img, 2)  # HWC to CHW\n        return torch.from_numpy(np_img).to(dtype=self.dtype)\n\n\n_pil_interpolation_to_str = {\n    Image.NEAREST: 'PIL.Image.NEAREST',\n    Image.BILINEAR: 'PIL.Image.BILINEAR',\n    Image.BICUBIC: 'PIL.Image.BICUBIC',\n    Image.LANCZOS: 'PIL.Image.LANCZOS',\n    Image.HAMMING: 'PIL.Image.HAMMING',\n    Image.BOX: 'PIL.Image.BOX',\n}\n\n\ndef _pil_interp(method):\n    if method == 'bicubic':\n        return Image.BICUBIC\n    elif method == 'lanczos':\n        return Image.LANCZOS\n    elif method == 'hamming':\n        return Image.HAMMING\n    else:\n        # default bilinear, do we want to allow nearest?\n        return Image.BILINEAR\n\n\n_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)\n\n\nclass RandomResizedCropAndInterpolationWithTwoPic:\n    \"\"\"Crop the given PIL Image to random size and aspect ratio with random interpolation.\n\n    A crop of random size (default: of 0.08 to 1.0) of the original size and a random\n    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop\n    is finally resized to given size.\n    This is popularly used to train the Inception networks.\n\n    Args:\n        size: expected output size of each edge\n        scale: range of size of the origin size cropped\n        ratio: range of aspect ratio of the origin aspect ratio cropped\n        interpolation: Default: PIL.Image.BILINEAR\n    \"\"\"\n\n    def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.),\n                 interpolation='bilinear', second_interpolation='lanczos'):\n        if isinstance(size, tuple):\n            self.size = size\n        else:\n            self.size = (size, size)\n        if second_size is not None:\n            if isinstance(second_size, tuple):\n                self.second_size = second_size\n            else:\n                self.second_size = (second_size, second_size)\n        else:\n            self.second_size = None\n        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):\n            warnings.warn(\"range should be of kind (min, max)\")\n\n        if interpolation == 'random':\n            self.interpolation = _RANDOM_INTERPOLATION\n        else:\n            self.interpolation = _pil_interp(interpolation)\n        self.second_interpolation = _pil_interp(second_interpolation)\n        self.scale = scale\n        self.ratio = ratio\n\n    @staticmethod\n    def get_params(img, scale, ratio):\n        \"\"\"Get parameters for ``crop`` for a random sized crop.\n\n        Args:\n            img (PIL Image): Image to be cropped.\n            scale (tuple): range of size of the origin size cropped\n            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped\n\n        Returns:\n            tuple: params (i, j, h, w) to be passed to ``crop`` for a random\n                sized crop.\n        \"\"\"\n        area = img.size[0] * img.size[1]\n\n        for attempt in range(10):\n            target_area = random.uniform(*scale) * area\n            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))\n            aspect_ratio = math.exp(random.uniform(*log_ratio))\n\n            w = int(round(math.sqrt(target_area * aspect_ratio)))\n            h = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if w <= img.size[0] and h <= img.size[1]:\n                i = random.randint(0, img.size[1] - h)\n                j = random.randint(0, img.size[0] - w)\n                return i, j, h, w\n\n        # Fallback to central crop\n        in_ratio = img.size[0] / img.size[1]\n        if in_ratio < min(ratio):\n            w = img.size[0]\n            h = int(round(w / min(ratio)))\n        elif in_ratio > max(ratio):\n            h = img.size[1]\n            w = int(round(h * max(ratio)))\n        else:  # whole image\n            w = img.size[0]\n            h = img.size[1]\n        i = (img.size[1] - h) // 2\n        j = (img.size[0] - w) // 2\n        return i, j, h, w\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be cropped and resized.\n\n        Returns:\n            PIL Image: Randomly cropped and resized image.\n        \"\"\"\n        i, j, h, w = self.get_params(img, self.scale, self.ratio)\n        if isinstance(self.interpolation, (tuple, list)):\n            interpolation = random.choice(self.interpolation)\n        else:\n            interpolation = self.interpolation\n        if self.second_size is None:\n            return F.resized_crop(img, i, j, h, w, self.size, interpolation)\n        else:\n            return F.resized_crop(img, i, j, h, w, self.size, interpolation), \\\n                   F.resized_crop(img, i, j, h, w, self.second_size, self.second_interpolation)\n\n    def __repr__(self):\n        if isinstance(self.interpolation, (tuple, list)):\n            interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])\n        else:\n            interpolate_str = _pil_interpolation_to_str[self.interpolation]\n    
    format_string = self.__class__.__name__ + '(size={0}'.format(self.size)\n        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))\n        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))\n        format_string += ', interpolation={0}'.format(interpolate_str)\n        if self.second_size is not None:\n            format_string += ', second_size={0}'.format(self.second_size)\n            format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])\n        format_string += ')'\n        return format_string\n"
  },
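`RandomResizedCropAndInterpolationWithTwoPic` draws a single crop window but resizes it twice when `second_size` is given, so the patch view fed to the encoder and the view fed to the discrete tokenizer stay spatially aligned. A small sketch on a synthetic PIL image (the sizes are illustrative):

```python
from PIL import Image
from furnace.transforms import RandomResizedCropAndInterpolationWithTwoPic

two_view_crop = RandomResizedCropAndInterpolationWithTwoPic(
    size=224, second_size=112,
    interpolation='bicubic', second_interpolation='lanczos')

img = Image.new('RGB', (500, 375))            # stand-in for a training image
patch_view, token_view = two_view_crop(img)   # same crop window, two resolutions
print(two_view_crop)
print(patch_view.size, token_view.size)       # (224, 224) (112, 112)
```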
  {
    "path": "furnace/utils.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on timm, DINO and DeiT code bases\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/facebookresearch/deit\n# https://github.com/facebookresearch/dino\n# --------------------------------------------------------'\nimport io\nimport os\nimport math\nimport time\nimport json\nfrom collections import defaultdict, deque\nimport datetime\nimport numpy as np\nfrom timm.utils import get_state_dict\n\nfrom pathlib import Path\n\nimport torch\nimport torch.distributed as dist\nfrom torch._six import inf\nfrom models.modeling_discrete_vae import Dalle_VAE, DiscreteVAE, VGGAN\n\nfrom tensorboardX import SummaryWriter\nimport torch.nn.functional as F\nfrom torch.nn.modules.batchnorm import _NormBase\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.4f} ({global_avg:.4f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def value(self):\n        return self.deque[-1]\n\n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value)\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if v is None:\n                continue\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, attr))\n\n    def __str__(self):\n        loss_str = []\n     
   for name, meter in self.meters.items():\n            loss_str.append(\n                \"{}: {}\".format(name, str(meter))\n            )\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 0\n        if not header:\n            header = ''\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt='{avg:.4f}')\n        data_time = SmoothedValue(fmt='{avg:.4f}')\n        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'\n        log_msg = [\n            header,\n            '[{0' + space_fmt + '}/{1}]',\n            'eta: {eta}',\n            '{meters}',\n            'time: {time}',\n            'data: {data}'\n        ]\n        if torch.cuda.is_available():\n            log_msg.append('max mem: {memory:.0f}')\n        log_msg = self.delimiter.join(log_msg)\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time),\n                        memory=torch.cuda.max_memory_allocated() / MB))\n                else:\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time)))\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print('{} Total time: {} ({:.4f} s / it)'.format(\n            header, total_time_str, total_time / len(iterable)))\n\n\nclass TensorboardLogger(object):\n    def __init__(self, log_dir):\n        self.writer = SummaryWriter(logdir=log_dir)\n        self.step = 0\n\n    def set_step(self, step=None):\n        if step is not None:\n            self.step = step\n        else:\n            self.step += 1\n\n    def update(self, head='scalar', step=None, **kwargs):\n        for k, v in kwargs.items():\n            if v is None:\n                continue\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.writer.add_scalar(head + \"/\" + k, v, self.step if step is None else step)\n\n    def flush(self):\n        self.writer.flush()\n\n\ndef _load_checkpoint_for_ema(model_ema, checkpoint):\n    \"\"\"\n    Workaround for ModelEma._load_checkpoint to accept an already-loaded object\n    \"\"\"\n    mem_file = io.BytesIO()\n    torch.save(checkpoint, mem_file)\n    mem_file.seek(0)\n    model_ema._load_checkpoint(mem_file)\n\ndef setup_for_distributed_each_gpu(rank):\n    import builtins as __builtin__\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        
builtin_print('rank is: ', rank, end=' ')\n        now = datetime.datetime.now().time()\n        builtin_print('[{}] '.format(now), end='')  # print with time stamp\n        builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    import builtins as __builtin__\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        if is_master or force:\n            now = datetime.datetime.now().time()\n            builtin_print('[{}] '.format(now), end='')  # print with time stamp\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef save_on_master(*args, **kwargs):\n    if is_main_process():\n        torch.save(*args, **kwargs)\n\n\ndef init_distributed_mode(args):\n    if args.dist_on_itp:\n        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])\n        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])\n        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])\n        args.dist_url = \"tcp://%s:%s\" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])\n        os.environ['LOCAL_RANK'] = str(args.gpu)\n        os.environ['RANK'] = str(args.rank)\n        os.environ['WORLD_SIZE'] = str(args.world_size)\n        # [\"RANK\", \"WORLD_SIZE\", \"MASTER_ADDR\", \"MASTER_PORT\", \"LOCAL_RANK\"]\n    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = int(os.environ['LOCAL_RANK'])\n    elif 'SLURM_PROCID' in os.environ:\n        args.rank = int(os.environ['SLURM_PROCID'])\n        args.gpu = args.rank % torch.cuda.device_count()\n    else:\n        print('Not using distributed mode')\n        args.distributed = False\n        return\n\n    args.distributed = True\n\n    torch.cuda.set_device(args.gpu)\n    args.dist_backend = 'nccl'\n    print('| distributed init (rank {}): {}, gpu {}'.format(\n        args.rank, args.dist_url, args.gpu), flush=True)\n    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,\n                                         world_size=args.world_size, rank=args.rank)\n    torch.distributed.barrier()\n    if not args.enable_multi_print:\n        setup_for_distributed(args.rank == 0)\n    else:\n        setup_for_distributed_each_gpu(args.rank)\n\n\ndef load_state_dict(model, state_dict, prefix='', ignore_missing=\"relative_position_index\"):\n    missing_keys = []\n    unexpected_keys = []\n    error_msgs = []\n    # copy state_dict so _load_from_state_dict can modify it\n    metadata = getattr(state_dict, '_metadata', None)\n    state_dict = state_dict.copy()\n    if metadata is not None:\n        state_dict._metadata = metadata\n\n    def load(module, prefix=''):\n        local_metadata = {} if metadata is None else metadata.get(\n            prefix[:-1], {})\n        module._load_from_state_dict(\n    
        state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)\n        for name, child in module._modules.items():\n            if child is not None:\n                load(child, prefix + name + '.')\n\n    load(model, prefix=prefix)\n\n    warn_missing_keys = []\n    ignore_missing_keys = []\n    for key in missing_keys:\n        keep_flag = True\n        for ignore_key in ignore_missing.split('|'):\n            if ignore_key in key:\n                keep_flag = False\n                break\n        if keep_flag:\n            warn_missing_keys.append(key)\n        else:\n            ignore_missing_keys.append(key)\n\n    missing_keys = warn_missing_keys\n\n    if len(missing_keys) > 0:\n        print(\"Weights of {} not initialized from pretrained model: {}\".format(\n            model.__class__.__name__, missing_keys))\n    if len(unexpected_keys) > 0:\n        print(\"Weights from pretrained model not used in {}: {}\".format(\n            model.__class__.__name__, unexpected_keys))\n    if len(ignore_missing_keys) > 0:\n        print(\"Ignored weights of {} not initialized from pretrained model: {}\".format(\n            model.__class__.__name__, ignore_missing_keys))\n    if len(error_msgs) > 0:\n        print('\\n'.join(error_msgs))\n\n\nclass NativeScalerWithGradNormCount:\n    state_dict_key = \"amp_scaler\"\n\n    def __init__(self):\n        self._scaler = torch.cuda.amp.GradScaler()\n\n    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):\n        self._scaler.scale(loss).backward(create_graph=create_graph)\n        if update_grad:\n            if clip_grad is not None:\n                assert parameters is not None\n                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place\n                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)\n            else:\n                self._scaler.unscale_(optimizer)\n                norm = get_grad_norm_(parameters)\n            self._scaler.step(optimizer)\n            self._scaler.update()\n        else:\n            norm = None\n        return norm\n\n    def state_dict(self):\n        return self._scaler.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self._scaler.load_state_dict(state_dict)\n\n\ndef get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = [p for p in parameters if p.grad is not None]\n    norm_type = float(norm_type)\n    if len(parameters) == 0:\n        return torch.tensor(0.)\n    device = parameters[0].grad.device\n    if norm_type == inf:\n        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)\n    else:\n        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)\n    return total_norm\n\n\ndef cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,\n                     start_warmup_value=0, warmup_steps=-1):\n    warmup_schedule = np.array([])\n    warmup_iters = warmup_epochs * niter_per_ep\n    if warmup_steps > 0:\n        warmup_iters = warmup_steps\n    print(\"Set warmup steps = %d\" % warmup_iters)\n    if warmup_epochs > 0:\n        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)\n\n    iters = np.arange(epochs * niter_per_ep - warmup_iters)\n    schedule = np.array(\n    
    [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])\n\n    schedule = np.concatenate((warmup_schedule, schedule))\n\n    assert len(schedule) == epochs * niter_per_ep\n    return schedule\n\n\ndef save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, exp_name=None):\n    output_dir = Path(args.output_dir)\n    epoch_name = str(epoch)\n    if loss_scaler is not None:\n        if exp_name is not None:\n            checkpoint_paths = [output_dir / ('{}_checkpoint-{}.pth'.format(exp_name, epoch_name))]\n        else:\n            checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]\n        for checkpoint_path in checkpoint_paths:\n            to_save_state_dict = model_without_ddp.state_dict()\n            # all_keys = list(state_dict.keys())\n                \n            for key in list(to_save_state_dict.keys()):\n                if key.startswith('teacher.'):\n                    to_save_state_dict.pop(key)\n\n            to_save = {\n                'model': to_save_state_dict,\n                'optimizer': optimizer.state_dict(),\n                'epoch': epoch,\n                'scaler': loss_scaler.state_dict(),\n                'args': args,\n            }\n\n            if model_ema is not None:\n                to_save['model_ema'] = get_state_dict(model_ema)\n\n            save_on_master(to_save, checkpoint_path)\n    else:\n        client_state = {'epoch': epoch}\n        if model_ema is not None:\n            client_state['model_ema'] = get_state_dict(model_ema)\n        if exp_name is not None:\n            model.save_checkpoint(save_dir=args.output_dir, tag=\"{}_checkpoint-{}\".format(exp_name, epoch_name), client_state=client_state)\n        else:\n            model.save_checkpoint(save_dir=args.output_dir, tag=\"checkpoint-%s\" % epoch_name, client_state=client_state)\n\n\ndef auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):\n    output_dir = Path(args.output_dir)\n    if loss_scaler is not None:\n        # torch.amp\n        if args.auto_resume and len(args.resume) == 0:\n            import glob\n            all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))\n            latest_ckpt = -1\n            for ckpt in all_checkpoints:\n                t = ckpt.split('-')[-1].split('.')[0]\n                if t.isdigit():\n                    latest_ckpt = max(int(t), latest_ckpt)\n            if latest_ckpt >= 0:\n                args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)\n            print(\"Auto resume checkpoint: %s\" % args.resume)\n\n        if args.resume:\n            if args.resume.startswith('https'):\n                checkpoint = torch.hub.load_state_dict_from_url(\n                    args.resume, map_location='cpu', check_hash=True)\n            else:\n                checkpoint = torch.load(args.resume, map_location='cpu')\n            \n            # handle ema model\n            need_state_dict = model_without_ddp.state_dict()\n            need_ema = False\n            for key in need_state_dict.keys():\n                if 'teacher' in key:\n                    need_ema = True\n                    break\n                \n            checkpoint_model = checkpoint['model']\n\n            if need_ema:\n                all_keys = list(checkpoint_model.keys())            \n                all_keys = [key for key in all_keys if key.startswith('encoder.')]\n          
      for key in all_keys:\n                    new_key = key.replace('encoder.','teacher.')\n                    checkpoint_model[new_key] = checkpoint_model[key]\n\n            model_without_ddp.load_state_dict(checkpoint_model)\n            print(\"Resume checkpoint %s\" % args.resume)\n            if 'optimizer' in checkpoint and 'epoch' in checkpoint:\n                optimizer.load_state_dict(checkpoint['optimizer'])\n                args.start_epoch = checkpoint['epoch'] + 1\n                if hasattr(args, 'model_ema') and args.model_ema:\n                    _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])\n                if 'scaler' in checkpoint:\n                    loss_scaler.load_state_dict(checkpoint['scaler'])\n                print(\"With optim & sched!\")\n    else:\n        # deepspeed, only support '--auto_resume'.\n        if args.auto_resume:\n            import glob\n            all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))\n            latest_ckpt = -1\n            for ckpt in all_checkpoints:\n                t = ckpt.split('-')[-1].split('.')[0]\n                if t.isdigit():\n                    latest_ckpt = max(int(t), latest_ckpt)\n            if latest_ckpt >= 0:\n                args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)\n                print(\"Auto resume checkpoint: %d\" % latest_ckpt)\n                _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)\n                args.start_epoch = client_states['epoch'] + 1\n                if model_ema is not None:\n                    if args.model_ema:\n                        _load_checkpoint_for_ema(model_ema, client_states['model_ema'])\n\n\ndef create_d_vae(weight_path, d_vae_type, image_size, device, args=None):\n    if d_vae_type == \"dall-e\":\n        return get_dalle_vae(weight_path, image_size, device)\n    if d_vae_type == \"vqgan_gumbel_f8_8192\":\n        return get_vqgan_gumbel_f8_8192(weight_path, image_size, device)\n    elif d_vae_type == \"customized\":\n        return get_d_vae(weight_path, image_size, device, args)\n    elif d_vae_type == \"to_tensor\":\n        return None\n    else:\n        raise NotImplementedError()\n\ndef get_vqgan_gumbel_f8_8192(weight_path, image_size, device):\n    with torch.no_grad():\n        vqgan = VGGAN(image_size)\n        vqgan.load_model(weight_path, device)\n\n        return vqgan \n\n\ndef get_dalle_vae(weight_path, image_size, device):\n    vae = Dalle_VAE(image_size)\n    vae.load_model(model_dir=weight_path, device=device)\n    return vae\n\n\ndef get_d_vae(weight_path, image_size, device, args):\n    NUM_TOKENS = 8192\n    NUM_LAYERS = args.dvae_num_layers\n    EMB_DIM = 512\n    HID_DIM = 256\n\n    state_dict = torch.load(weight_path, map_location=\"cpu\")[\"model\"]\n\n    model = DiscreteVAE(\n        image_size=image_size,\n        num_layers=NUM_LAYERS,\n        num_tokens=NUM_TOKENS,\n        codebook_dim=EMB_DIM,\n        hidden_dim=HID_DIM,\n    ).to(device)\n\n    model.load_state_dict(state_dict)\n    return model\n\n\ndef create_ds_config(args):\n    args.deepspeed_config = os.path.join(args.output_dir, \"deepspeed_config.json\")\n    with open(args.deepspeed_config, mode=\"w\") as writer:\n        ds_config = {\n            \"train_batch_size\": args.batch_size * args.update_freq * get_world_size(),\n            \"train_micro_batch_size_per_gpu\": args.batch_size,\n            \"steps_per_print\": 1000,\n            \"optimizer\": 
{\n                \"type\": \"Adam\",\n                \"adam_w_mode\": True,\n                \"params\": {\n                    \"lr\": args.lr,\n                    \"weight_decay\": args.weight_decay,\n                    \"bias_correction\": True,\n                    \"betas\": [\n                        0.9,\n                        0.999\n                    ],\n                    \"eps\": 1e-8\n                }\n            },\n            \"fp16\": {\n                \"enabled\": True,\n                \"loss_scale\": 0,\n                \"initial_scale_power\": 7,\n                \"loss_scale_window\": 128\n            }\n        }\n\n        writer.write(json.dumps(ds_config, indent=2))\n\n\nclass LP_BatchNorm(_NormBase):\n    \"\"\" A variant used in linear probing.\n    To freeze parameters (normalization operator specifically), model set to eval mode during linear probing.\n    According to paper, an extra BN is used on the top of encoder to calibrate the feature magnitudes.\n    In addition to self.training, we set another flag in this implement to control BN's behavior to train in eval mode.\n    \"\"\"\n\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,\n                 track_running_stats=True):\n        super(LP_BatchNorm, self).__init__(\n            num_features, eps, momentum, affine, track_running_stats)\n\n    def _check_input_dim(self, input):\n        if input.dim() != 2 and input.dim() != 3:\n            raise ValueError('expected 2D or 3D input (got {}D input)'\n                             .format(input.dim()))\n\n    def forward(self, input, is_train):\n        \"\"\"\n        We use is_train instead of self.training.\n        \"\"\"\n        self._check_input_dim(input)\n        # exponential_average_factor is set to self.momentum\n        # (when it is available) only so that it gets updated\n        # in ONNX graph when this node is exported to ONNX.\n        if self.momentum is None:\n            exponential_average_factor = 0.0\n        else:\n            exponential_average_factor = self.momentum\n\n        # if self.training and self.track_running_stats:\n        if is_train and self.track_running_stats:\n            if self.num_batches_tracked is not None:  # type: ignore\n                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore\n                if self.momentum is None:  # use cumulative moving average\n                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)\n                else:  # use exponential moving average\n                    exponential_average_factor = self.momentum\n\n        r\"\"\"\n        Decide whether the mini-batch stats should be used for normalization rather than the buffers.\n        Mini-batch stats are used in training mode, and in eval mode when buffers are None.\n        \"\"\"\n        if is_train:\n            bn_training = True\n        else:\n            bn_training = (self.running_mean is None) and (self.running_var is None)\n\n        r\"\"\"\n        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be\n        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are\n        used for normalization (i.e. 
in eval mode when buffers are not None).\n        \"\"\"\n        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)\n        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)\n        return F.batch_norm(\n            input,\n            # If buffers are not to be tracked, ensure that they won't be updated\n            self.running_mean if not is_train or self.track_running_stats else None,\n            self.running_var if not is_train or self.track_running_stats else None,\n            self.weight, self.bias, bn_training, exponential_average_factor, self.eps)\n"
  },
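`cosine_scheduler` precomputes one value per iteration (linear warmup, then a half-cosine decay to `final_value`); the pretraining engine indexes it with the global step `start_steps + step` and scales it by each group's `lr_scale`. A short sketch of building and inspecting such a schedule, run from the repo root so `furnace.utils` and its dependencies import; the sizes are illustrative:

```python
from furnace.utils import cosine_scheduler

# Illustrative sizes: 100 iterations per epoch, 800 epochs, 20 warmup epochs.
lr_schedule = cosine_scheduler(base_value=1.5e-3, final_value=1e-5,
                               epochs=800, niter_per_ep=100, warmup_epochs=20)
assert len(lr_schedule) == 800 * 100

it = 5 * 100 + 42                 # global step, as computed in train_one_epoch
print(lr_schedule[0])             # warmup start (0.0 by default)
print(lr_schedule[20 * 100])      # end of warmup: the base value
print(lr_schedule[it], lr_schedule[-1])
```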
  {
    "path": "linear_util/crop.py",
    "content": "import math\n\nimport torch\n\nfrom torchvision import transforms\nfrom torchvision.transforms import functional as F\n\n\nclass RandomResizedCrop(transforms.RandomResizedCrop):\n    \"\"\"\n    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.\n    This may lead to results different with torchvision's version.\n    Following BYOL's TF code:\n    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206\n    \"\"\"\n    @staticmethod\n    def get_params(img, scale, ratio):\n        width, height = F._get_image_size(img)\n        area = height * width\n\n        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()\n        log_ratio = torch.log(torch.tensor(ratio))\n        aspect_ratio = torch.exp(\n            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])\n        ).item()\n\n        w = int(round(math.sqrt(target_area * aspect_ratio)))\n        h = int(round(math.sqrt(target_area / aspect_ratio)))\n\n        w = min(w, width)\n        h = min(h, height)\n\n        i = torch.randint(0, height - h + 1, size=(1,)).item()\n        j = torch.randint(0, width - w + 1, size=(1,)).item()\n\n        return i, j, h, w"
  },
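This crop mirrors BYOL's TF implementation by sampling the crop box in one shot instead of retrying in a loop. A minimal sketch of dropping it into a linear-probing training pipeline (the augmentation choices here are placeholders):

```python
from torchvision import transforms
from linear_util.crop import RandomResizedCrop

train_transform = transforms.Compose([
    RandomResizedCrop(224, interpolation=3),   # 3 = bicubic
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
print(train_transform)
```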
  {
    "path": "linear_util/datasets.py",
    "content": "import os\nimport PIL\n\nfrom torchvision import datasets, transforms\n\nfrom timm.data import create_transform\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom furnace.masking_generator import MaskingGenerator, RandomMaskingGenerator\n\nclass DataAugmentationMySelf(object):\n    def __init__(self, args):\n        self.patch_transform =   transforms.Compose([\n            transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n        \n        if args.mask_generator == 'block':\n            self.masked_position_generator = MaskingGenerator(\n                args.window_size, num_masking_ratio=args.mask_ratio,\n                max_num_patches=args.max_mask_patches_per_block,\n                min_num_patches=args.min_mask_patches_per_block,\n            )\n        elif args.mask_generator == 'random':\n            self.masked_position_generator = RandomMaskingGenerator(\n                args.window_size, ratio_masking_patches=args.mask_ratio\n            )\n\n    def __call__(self, image):\n        return self.patch_transform(image), self.masked_position_generator()\n\n    def __repr__(self):\n        repr = \"(DataAugmentationMySelf,\\n\"\n        repr += \"  patch_transform = %s,\\n\" % str(self.patch_transform)\n        repr += \"  Masked position generator = %s,\\n\" % str(self.masked_position_generator)\n        repr += \")\"\n        return repr\n\ndef build_dataset(is_train, args):\n    transform = build_transform(is_train, args)\n\n    root = os.path.join(args.data_path, 'train' if is_train else 'val')\n    dataset = datasets.ImageFolder(root, transform=transform)\n\n    print(dataset)\n\n    return dataset\n\ndef build_dataset_finetune(is_train, args):\n    transform = build_transform_finetune(is_train, args)\n\n    root = os.path.join(args.data_path, 'train' if is_train else 'val')\n    dataset = datasets.ImageFolder(root, transform=transform)\n\n    print(dataset)\n\n    return dataset\n\ndef build_transform_finetune(is_train, args):\n    mean = IMAGENET_DEFAULT_MEAN\n    std = IMAGENET_DEFAULT_STD\n    # train transform\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=args.input_size,\n            is_training=True,\n            color_jitter=args.color_jitter,\n            auto_augment=args.aa,\n            interpolation='bicubic',\n            re_prob=args.reprob,\n            re_mode=args.remode,\n            re_count=args.recount,\n            mean=mean,\n            std=std,\n        )\n        return transform\n\n    # eval transform\n    t = []\n    if args.input_size <= 224:\n        crop_pct = 224 / 256\n    else:\n        crop_pct = 1.0\n    size = int(args.input_size / crop_pct)\n    t.append(\n        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 
224 images\n    )\n    t.append(transforms.CenterCrop(args.input_size))\n\n    t.append(transforms.ToTensor())\n    t.append(transforms.Normalize(mean, std))\n    return transforms.Compose(t)\n\ndef build_transform(is_train, args):\n    mean = IMAGENET_DEFAULT_MEAN\n    std = IMAGENET_DEFAULT_STD\n    # train transform\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=args.input_size,\n            is_training=True,\n            color_jitter=args.color_jitter,\n            auto_augment=args.aa,\n            interpolation='bicubic',\n            re_prob=args.reprob,\n            re_mode=args.remode,\n            re_count=args.recount,\n            mean=mean,\n            std=std,\n        )\n        return DataAugmentationMySelf(args, transform)\n\n    # eval transform\n    t = []\n    if args.input_size <= 224:\n        crop_pct = 224 / 256\n    else:\n        crop_pct = 1.0\n    size = int(args.input_size / crop_pct)\n    t.append(\n        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images\n    )\n    t.append(transforms.CenterCrop(args.input_size))\n\n    t.append(transforms.ToTensor())\n    t.append(transforms.Normalize(mean, std))\n    return transforms.Compose(t)\n"
  },
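`build_transform_finetune` resolves to timm's `create_transform` for training and to a resize / center-crop / normalize pipeline for evaluation; `build_dataset_finetune` then wraps the standard `train/` and `val/` ImageFolder layout with it. A sketch of the evaluation transform with a hypothetical argument namespace (only `input_size` is read on this path):

```python
import argparse
from linear_util.datasets import build_transform_finetune

args = argparse.Namespace(input_size=224)
eval_transform = build_transform_finetune(is_train=False, args=args)
print(eval_transform)   # Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize
```

Passing the same namespace, extended with `data_path` and the timm augmentation fields (`color_jitter`, `aa`, `reprob`, `remode`, `recount`), to `build_dataset_finetune` yields the corresponding ImageFolder datasets.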
  {
    "path": "linear_util/engine_finetune.py",
    "content": "import math\nimport sys\nfrom typing import Iterable, Optional\n\nimport torch\n\nfrom timm.data import Mixup\nfrom timm.utils import accuracy\n\nimport linear_util.misc as misc\nimport linear_util.lr_sched as lr_sched\n\n\ndef train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,\n                    data_loader: Iterable, optimizer: torch.optim.Optimizer,\n                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,\n                    mixup_fn: Optional[Mixup] = None, log_writer=None,\n                    args=None):\n    model.train(True)\n    metric_logger = misc.MetricLogger(delimiter=\"  \")\n    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    header = 'Epoch: [{}]'.format(epoch)\n    print_freq = 20\n\n    accum_iter = args.accum_iter\n\n    optimizer.zero_grad()\n\n    if log_writer is not None:\n        print('log_dir: {}'.format(log_writer.log_dir))\n\n    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):\n\n        # we use a per iteration (instead of per epoch) lr scheduler\n        if data_iter_step % accum_iter == 0:\n            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)\n\n        samples = samples.to(device, non_blocking=True)\n        targets = targets.to(device, non_blocking=True)\n\n        if mixup_fn is not None:\n            samples, targets = mixup_fn(samples, targets)\n\n        with torch.cuda.amp.autocast():\n            outputs = model(samples)\n            loss = criterion(outputs, targets)\n\n        loss_value = loss.item()\n\n        if not math.isfinite(loss_value):\n            print(\"Loss is {}, stopping training\".format(loss_value))\n            sys.exit(1)\n\n        loss /= accum_iter\n        loss_scaler(loss, optimizer, clip_grad=max_norm,\n                    parameters=model.parameters(), create_graph=False,\n                    update_grad=(data_iter_step + 1) % accum_iter == 0)\n        if (data_iter_step + 1) % accum_iter == 0:\n            optimizer.zero_grad()\n\n        torch.cuda.synchronize()\n\n        metric_logger.update(loss=loss_value)\n        min_lr = 10.\n        max_lr = 0.\n        for group in optimizer.param_groups:\n            min_lr = min(min_lr, group[\"lr\"])\n            max_lr = max(max_lr, group[\"lr\"])\n\n        metric_logger.update(lr=max_lr)\n\n        loss_value_reduce = misc.all_reduce_mean(loss_value)\n        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:\n            \"\"\" We use epoch_1000x as the x-axis in tensorboard.\n            This calibrates different curves when batch size changes.\n            \"\"\"\n            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)\n            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)\n            log_writer.add_scalar('lr', max_lr, epoch_1000x)\n\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print(\"Averaged stats:\", metric_logger)\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}\n\n\n@torch.no_grad()\ndef evaluate(data_loader, model, device):\n    criterion = torch.nn.CrossEntropyLoss()\n\n    metric_logger = misc.MetricLogger(delimiter=\"  \")\n    header = 'Test:'\n\n    # switch to evaluation mode\n    model.eval()\n\n    for batch in metric_logger.log_every(data_loader, 10, header):\n        images = batch[0]\n   
     target = batch[-1]\n        images = images.to(device, non_blocking=True)\n        target = target.to(device, non_blocking=True)\n\n        # compute output\n        with torch.cuda.amp.autocast():\n            output = model(images)\n            loss = criterion(output, target)\n\n        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n        batch_size = images.shape[0]\n        metric_logger.update(loss=loss.item())\n        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)\n        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'\n          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))\n\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}"
  },
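The loop above scales the loss by `1 / accum_iter` and steps the optimizer only every `accum_iter` iterations, so the effective batch size is `batch_size * accum_iter * world_size`. A bare-bones sketch of that accumulation pattern without the AMP loss scaler (model, data, and counts are placeholders):

```python
import torch

accum_iter = 4
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()

optimizer.zero_grad()
for data_iter_step in range(16):
    samples, targets = torch.randn(32, 8), torch.randint(0, 2, (32,))
    loss = criterion(model(samples), targets) / accum_iter   # average over accumulated steps
    loss.backward()
    if (data_iter_step + 1) % accum_iter == 0:               # update every accum_iter iterations
        optimizer.step()
        optimizer.zero_grad()
```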
  {
    "path": "linear_util/lars.py",
    "content": "import torch\n\n\nclass LARS(torch.optim.Optimizer):\n    \"\"\"\n    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.\n    \"\"\"\n    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):\n        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)\n        super().__init__(params, defaults)\n\n    @torch.no_grad()\n    def step(self):\n        for g in self.param_groups:\n            for p in g['params']:\n                dp = p.grad\n\n                if dp is None:\n                    continue\n\n                if p.ndim > 1: # if not normalization gamma/beta or bias\n                    dp = dp.add(p, alpha=g['weight_decay'])\n                    param_norm = torch.norm(p)\n                    update_norm = torch.norm(dp)\n                    one = torch.ones_like(param_norm)\n                    q = torch.where(param_norm > 0.,\n                                    torch.where(update_norm > 0,\n                                    (g['trust_coefficient'] * param_norm / update_norm), one),\n                                    one)\n                    dp = dp.mul(q)\n\n                param_state = self.state[p]\n                if 'mu' not in param_state:\n                    param_state['mu'] = torch.zeros_like(p)\n                mu = param_state['mu']\n                mu.mul_(g['momentum']).add_(dp)\n                p.add_(mu, alpha=-g['lr'])"
  },
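LARS applies the trust-ratio scaling and weight decay only to parameters with more than one dimension, so biases and normalization parameters are updated with plain momentum SGD, which suits large-batch linear probing. A minimal usage sketch on a linear classification head over frozen features (sizes and hyper-parameters are placeholders):

```python
import torch
import torch.nn.functional as F
from linear_util.lars import LARS

head = torch.nn.Linear(768, 1000)                       # linear probe on 768-d features
optimizer = LARS(head.parameters(), lr=0.1, weight_decay=0.0)

features = torch.randn(16, 768)                         # stand-in for frozen encoder outputs
targets = torch.randint(0, 1000, (16,))
loss = F.cross_entropy(head(features), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```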
  {
    "path": "linear_util/lr_decay.py",
    "content": "import json\n\n\ndef param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):\n    param_group_names = {}\n    param_groups = {}\n\n    num_layers = len(model.blocks) + 1\n\n    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))\n\n    for n, p in model.named_parameters():\n        if not p.requires_grad:\n            continue\n\n        # no decay: all 1D parameters and model specific ones\n        if p.ndim == 1 or n in no_weight_decay_list:\n            g_decay = \"no_decay\"\n            this_decay = 0.\n        else:\n            g_decay = \"decay\"\n            this_decay = weight_decay\n            \n        layer_id = get_layer_id_for_vit(n, num_layers)\n        group_name = \"layer_%d_%s\" % (layer_id, g_decay)\n\n        if group_name not in param_group_names:\n            this_scale = layer_scales[layer_id]\n\n            param_group_names[group_name] = {\n                \"lr_scale\": this_scale,\n                \"weight_decay\": this_decay,\n                \"params\": [],\n            }\n            param_groups[group_name] = {\n                \"lr_scale\": this_scale,\n                \"weight_decay\": this_decay,\n                \"params\": [],\n            }\n\n        param_group_names[group_name][\"params\"].append(n)\n        param_groups[group_name][\"params\"].append(p)\n\n    # print(\"parameter groups: \\n%s\" % json.dumps(param_group_names, indent=2))\n\n    return list(param_groups.values())\n\n\ndef get_layer_id_for_vit(name, num_layers):\n    if name in ['cls_token', 'pos_embed']:\n        return 0\n    elif name.startswith('patch_embed'):\n        return 0\n    elif name.startswith('blocks'):\n        return int(name.split('.')[1]) + 1\n    else:\n        return num_layers"
  },
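`param_groups_lrd` gives each block the geometric scale `layer_decay ** (num_layers - layer_id)`, so earlier layers receive smaller effective learning rates during fine-tuning. A short numeric sketch of those scales for a hypothetical 12-block backbone:

```python
layer_decay = 0.75          # hypothetical decay factor
num_layers = 12 + 1         # len(model.blocks) + 1, as in param_groups_lrd
layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

# Layer id 0 covers patch_embed / cls_token / pos_embed; the last id covers the head.
for layer_id, scale in enumerate(layer_scales):
    print(f"layer {layer_id:2d}: lr_scale = {scale:.4f}")
# layer 0 -> 0.75**13 ~= 0.0238, ..., layer 13 -> 0.75**0 = 1.0
```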
  {
    "path": "linear_util/lr_sched.py",
    "content": "import math\n\ndef adjust_learning_rate(optimizer, epoch, args):\n    \"\"\"Decay the learning rate with half-cycle cosine after warmup\"\"\"\n    if epoch < args.warmup_epochs:\n        lr = args.lr * epoch / args.warmup_epochs \n    else:\n        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \\\n            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))\n    for param_group in optimizer.param_groups:\n        if \"lr_scale\" in param_group:\n            param_group[\"lr\"] = lr * param_group[\"lr_scale\"]\n        else:\n            param_group[\"lr\"] = lr\n    return lr\n"
  },
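`adjust_learning_rate` ramps the lr linearly over `warmup_epochs` and then follows a half-cycle cosine down to `min_lr`; since callers pass `data_iter_step / len(data_loader) + epoch`, the schedule is effectively per iteration. A small sketch with placeholder hyper-parameters:

```python
import argparse
import torch
from linear_util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=0.1, min_lr=1e-6, warmup_epochs=10, epochs=90)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in (0.0, 5.0, 10.0, 50.0, 90.0):
    lr = adjust_learning_rate(optimizer, epoch, args)
    print(f"epoch {epoch:5.1f}: lr = {lr:.6f}")
# 0.0 -> 0.0 (warmup start), 5.0 -> 0.05, 10.0 -> 0.1 (peak), 50.0 -> ~0.05, 90.0 -> min_lr
```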
  {
    "path": "linear_util/misc.py",
    "content": "import builtins\nimport datetime\nimport os\nimport time\nfrom collections import defaultdict, deque\nfrom pathlib import Path\n\nimport torch\nimport torch.distributed as dist\nfrom torch._six import inf\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.4f} ({global_avg:.4f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def value(self):\n        return self.deque[-1]\n\n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value)\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if v is None:\n                continue\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, attr))\n\n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\n                \"{}: {}\".format(name, str(meter))\n            )\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 0\n        if not header:\n            header = ''\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt='{avg:.4f}')\n        data_time = SmoothedValue(fmt='{avg:.4f}')\n        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'\n        log_msg = [\n            header,\n            '[{0' + space_fmt + '}/{1}]',\n            
'eta: {eta}',\n            '{meters}',\n            'time: {time}',\n            'data: {data}'\n        ]\n        if torch.cuda.is_available():\n            log_msg.append('max mem: {memory:.0f}')\n        log_msg = self.delimiter.join(log_msg)\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time),\n                        memory=torch.cuda.max_memory_allocated() / MB))\n                else:\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time)))\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print('{} Total time: {} ({:.4f} s / it)'.format(\n            header, total_time_str, total_time / len(iterable)))\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    builtin_print = builtins.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        force = force or (get_world_size() > 8)\n        if is_master or force:\n            now = datetime.datetime.now().time()\n            builtin_print('[{}] '.format(now), end='')  # print with time stamp\n            builtin_print(*args, **kwargs)\n\n    builtins.print = print\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef save_on_master(*args, **kwargs):\n    if is_main_process():\n        torch.save(*args, **kwargs)\n\n\ndef init_distributed_mode(args):\n    if args.dist_on_itp:\n        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])\n        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])\n        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])\n        args.dist_url = \"tcp://%s:%s\" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])\n        os.environ['LOCAL_RANK'] = str(args.gpu)\n        os.environ['RANK'] = str(args.rank)\n        os.environ['WORLD_SIZE'] = str(args.world_size)\n        # [\"RANK\", \"WORLD_SIZE\", \"MASTER_ADDR\", \"MASTER_PORT\", \"LOCAL_RANK\"]\n    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = int(os.environ['LOCAL_RANK'])\n    elif 'SLURM_PROCID' in os.environ:\n        args.rank = int(os.environ['SLURM_PROCID'])\n        args.gpu = args.rank % 
torch.cuda.device_count()\n    else:\n        print('Not using distributed mode')\n        setup_for_distributed(is_master=True)  # hack\n        args.distributed = False\n        return\n\n    args.distributed = True\n\n    torch.cuda.set_device(args.gpu)\n    args.dist_backend = 'nccl'\n    print('| distributed init (rank {}): {}, gpu {}'.format(\n        args.rank, args.dist_url, args.gpu), flush=True)\n    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,\n                                         world_size=args.world_size, rank=args.rank)\n    torch.distributed.barrier()\n    setup_for_distributed(args.rank == 0)\n\n\nclass NativeScalerWithGradNormCount:\n    state_dict_key = \"amp_scaler\"\n\n    def __init__(self):\n        self._scaler = torch.cuda.amp.GradScaler()\n\n    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):\n        self._scaler.scale(loss).backward(create_graph=create_graph)\n        if update_grad:\n            if clip_grad is not None:\n                assert parameters is not None\n                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place\n                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)\n            else:\n                self._scaler.unscale_(optimizer)\n                norm = get_grad_norm_(parameters)\n            self._scaler.step(optimizer)\n            self._scaler.update()\n        else:\n            norm = None\n        return norm\n\n    def state_dict(self):\n        return self._scaler.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self._scaler.load_state_dict(state_dict)\n\n\ndef get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = [p for p in parameters if p.grad is not None]\n    norm_type = float(norm_type)\n    if len(parameters) == 0:\n        return torch.tensor(0.)\n    device = parameters[0].grad.device\n    if norm_type == inf:\n        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)\n    else:\n        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)\n    return total_norm\n\n\ndef save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, exp_name=None):\n    output_dir = Path(args.output_dir)\n    epoch_name = str(epoch)\n    if loss_scaler is not None:\n        if exp_name is not None:\n            checkpoint_paths = [output_dir / ('{}_checkpoint-{}.pth'.format(exp_name, epoch_name))]\n        else:\n            checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]\n\n        for checkpoint_path in checkpoint_paths:\n            to_save = {\n                'model': model_without_ddp.state_dict(),\n                'optimizer': optimizer.state_dict(),\n                'epoch': epoch,\n                'scaler': loss_scaler.state_dict(),\n                'args': args,\n            }\n\n            save_on_master(to_save, checkpoint_path)\n    else:\n        client_state = {'epoch': epoch}\n        if exp_name is not None:\n            model.save_checkpoint(save_dir=args.output_dir, tag=\"{}_checkpoint-{}\".format(exp_name, epoch_name), client_state=client_state)\n        else:\n            model.save_checkpoint(save_dir=args.output_dir, tag=\"checkpoint-%s\" % epoch_name, 
client_state=client_state)\n\n\ndef load_model(args, model_without_ddp, optimizer, loss_scaler):\n    if args.resume:\n        if args.resume.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.resume, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.resume, map_location='cpu')\n\n        checkpoint_model = checkpoint['model']\n        model_without_ddp.load_state_dict(checkpoint_model)\n        print(\"Resume checkpoint %s\" % args.resume)\n        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):\n            optimizer.load_state_dict(checkpoint['optimizer'])\n            args.start_epoch = checkpoint['epoch'] + 1\n            if 'scaler' in checkpoint:\n                loss_scaler.load_state_dict(checkpoint['scaler'])\n            print(\"With optim & sched!\")\n\n\ndef all_reduce_mean(x):\n    world_size = get_world_size()\n    if world_size > 1:\n        x_reduce = torch.tensor(x).cuda()\n        dist.all_reduce(x_reduce)\n        x_reduce /= world_size\n        return x_reduce.item()\n    else:\n        return x"
  },
  {
    "path": "linear_util/pos_embed.py",
    "content": "import numpy as np\n\nimport torch\n\ndef get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):\n    \"\"\"\n    grid_size: int of the grid height and width\n    return:\n    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    grid_h = np.arange(grid_size, dtype=np.float32)\n    grid_w = np.arange(grid_size, dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n\n    grid = grid.reshape([2, 1, grid_size, grid_size])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if cls_token:\n        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n    assert embed_dim % 2 == 0\n\n    # use half of dimensions to encode grid_h\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)\n\n    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position\n    pos: a list of positions to be encoded: size (M,)\n    out: (M, D)\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = np.arange(embed_dim // 2, dtype=np.float)\n    omega /= embed_dim / 2.\n    omega = 1. / 10000**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out) # (M, D/2)\n    emb_cos = np.cos(out) # (M, D/2)\n\n    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n    return emb\n\n\n# --------------------------------------------------------\n# Interpolate position embeddings for high-resolution\n# References:\n# DeiT: https://github.com/facebookresearch/deit\n# --------------------------------------------------------\ndef interpolate_pos_embed(model, checkpoint_model):\n    if 'pos_embed' in checkpoint_model:\n        pos_embed_checkpoint = checkpoint_model['pos_embed']\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        new_size = int(num_patches ** 0.5)\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size:\n            print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            checkpoint_model['pos_embed'] = new_pos_embed\n"
  },
  {
    "path": "models/modeling_cae.py",
    "content": "import math\nimport time\nimport torch\nimport torch.nn as nn\nfrom functools import partial\n\nfrom models.modeling_finetune import _cfg, PatchEmbed\nfrom timm.models.registry import register_model\nfrom timm.models.layers import trunc_normal_ as __call_trunc_normal_\nfrom models.modeling_cae_helper import *\n\ndef trunc_normal_(tensor, mean=0., std=1.):\n    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)\n\n\nclass VisionTransformerForMaskedImageModeling(nn.Module):\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,\n                 use_abs_pos_emb=True, init_std=0.02, args=None, **kwargs):\n        super().__init__()\n\n        self.encoder = VisionTransformerEncoder(img_size=img_size, patch_size=patch_size, in_chans=in_chans, \n                 vocab_size=vocab_size, embed_dim=embed_dim, depth=depth,\n                 num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                 drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,\n                 norm_layer=norm_layer, init_values=init_values, attn_head_dim=attn_head_dim,\n                 use_abs_pos_emb=use_abs_pos_emb, init_std=init_std, args=args)\n\n        # alignment constraint\n        self.teacher = VisionTransformerEncoder(img_size=img_size, patch_size=patch_size, in_chans=in_chans, \n                vocab_size=vocab_size, embed_dim=embed_dim, depth=depth,\n                num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,\n                norm_layer=norm_layer, init_values=init_values, attn_head_dim=attn_head_dim,\n                use_abs_pos_emb=use_abs_pos_emb, init_std=init_std, args=args)\n\n        self.init_std = init_std\n        self.args = args\n        self.num_patches = self.encoder.patch_embed.num_patches\n\n        self.pretext_neck = VisionTransformerNeck(patch_size=patch_size, num_classes=args.decoder_num_classes, embed_dim=args.decoder_embed_dim, depth=args.regressor_depth,\n            num_heads=args.decoder_num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,\n            drop_path_rate=drop_path_rate, norm_layer=norm_layer, init_values=args.decoder_layer_scale_init_value, num_patches=self.num_patches, init_std=init_std, args=args)\n\n        # encoder to decoder projection, borrowed from mae.\n        if args.decoder_embed_dim != embed_dim:\n            self.encoder_to_decoder = nn.Linear(embed_dim, args.decoder_embed_dim, bias=True)\n            self.encoder_to_decoder_norm = norm_layer(args.decoder_embed_dim)\n        else:\n            self.encoder_to_decoder = None\n\n        self.mask_token = nn.Parameter(torch.zeros(1, 1, args.decoder_embed_dim))\n        trunc_normal_(self.mask_token, std=self.init_std)\n\n        ### whether to use 'rescale' to init the weight, borrowed from beit.\n        if not args.fix_init_weight:\n            self.apply(self._init_weights)\n        self._init_teacher()\n        \n        \n    def _init_teacher(self):  \n        # init the weights of teacher with those of backbone\n        for param_encoder, param_teacher in 
zip(self.encoder.parameters(), self.teacher.parameters()):\n            param_teacher.detach()\n            param_teacher.data.copy_(param_encoder.data)\n            param_teacher.requires_grad = False\n\n    def momentum_update(self, base_momentum=0):\n        \"\"\"Momentum update of the teacher network.\"\"\"\n        for param_encoder, param_teacher in zip(self.encoder.parameters(),\n                                                self.teacher.parameters()):\n            param_teacher.data = param_teacher.data * base_momentum + \\\n                param_encoder.data * (1. - base_momentum)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            nn.init.xavier_uniform_(m.weight)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n    \n    '''\n    Input shape:\n        x: [bs, 3, 224, 224]\n        bool_masked_pos: [bs, num_patch * num_patch]\n    '''\n    def forward(self, x, bool_masked_pos, return_all_tokens=None):\n        batch_size = x.size(0)\n\n        '''\n        Encoder\n        Output shape:\n            [bs, num_visible + 1, C]\n        '''\n        x_unmasked = self.encoder(x, bool_masked_pos=bool_masked_pos)\n\n        # encoder to decoder projection\n        if self.encoder_to_decoder is not None:\n            x_unmasked = self.encoder_to_decoder(x_unmasked)\n            x_unmasked = self.encoder_to_decoder_norm(x_unmasked)\n\n        '''\n        Alignment constraint\n        '''\n        with torch.no_grad():\n            latent_target = self.teacher(x, bool_masked_pos=(~bool_masked_pos))\n            latent_target = latent_target[:, 1:, :] # remove class token\n            if self.encoder_to_decoder is not None:\n                latent_target = self.encoder_to_decoder_norm(self.encoder_to_decoder(latent_target.detach()))\n\n            self.momentum_update(self.args.base_momentum)\n\n        '''\n        Latent contextual regressor and decoder\n        '''\n        b, num_visible_plus1, dim = x_unmasked.shape\n        # remove class token\n        x_unmasked = x_unmasked[:, 1:, :]\n\n        num_masked_patches = self.num_patches - (num_visible_plus1-1)\n        \n        # generate position embeddings.\n        pos_embed = self.encoder.build_2d_sincos_position_embedding(dim, use_cls_token=True).expand(batch_size, self.num_patches+1, dim).cuda(x_unmasked.device)\n\n        # pos embed for masked patches\n        pos_embed_masked = pos_embed[:,1:][bool_masked_pos].reshape(batch_size, -1, dim) \n\n        # pos embed for unmasked patches\n        pos_embed_unmasked = pos_embed[:,1:][~bool_masked_pos].reshape(batch_size, -1, dim) \n\n        # masked embedding '''\n        x_masked = self.mask_token.expand(batch_size, num_masked_patches, -1)\n\n        logits, latent_pred = self.pretext_neck(x_masked, x_unmasked, pos_embed_masked, pos_embed_unmasked, bool_masked_pos)\n        logits = logits.view(-1, logits.shape[2])\n\n        return logits, latent_pred, latent_target\n\n\n@register_model\ndef cae_small_patch16_224_8k_vocab(pretrained=False, **kwargs):\n    model = VisionTransformerForMaskedImageModeling(\n        patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = 
torch.load(\n            kwargs[\"init_ckpt\"], map_location=\"cpu\"\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef cae_base_patch16_224_8k_vocab(pretrained=False, **kwargs):\n    model = VisionTransformerForMaskedImageModeling(\n        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.load(\n            kwargs[\"init_ckpt\"], map_location=\"cpu\"\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef cae_large_patch16_224_8k_vocab(pretrained=False, **kwargs):\n    model = VisionTransformerForMaskedImageModeling(\n        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.load(\n            kwargs[\"init_ckpt\"], map_location=\"cpu\"\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n"
  },
  {
    "path": "models/modeling_cae_helper.py",
    "content": "import math\nimport time\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom functools import partial\nfrom models.modeling_finetune import PatchEmbed, DropPath, Mlp\nfrom timm.models.registry import register_model\nfrom timm.models.layers import trunc_normal_ as __call_trunc_normal_\n\ndef trunc_normal_(tensor, mean=0., std=1.):\n    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n        \n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, bool_masked_pos=None):\n\n        B, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n        \n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))    # (B, N_head, N, N)\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1) \n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n'''\nModified from Attention()\n'''\nclass CrossAttention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.q = nn.Linear(dim, all_head_dim, bias=False)\n        self.k = nn.Linear(dim, all_head_dim, bias=False)\n        self.v = nn.Linear(dim, all_head_dim, bias=False)\n\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.k_bias = None\n            self.v_bias = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, bool_masked_pos=None, k=None, v=None):\n        B, N, C = x.shape\n        N_k = k.shape[1]\n        N_v = v.shape[1]\n\n        q_bias, k_bias, v_bias = None, None, None\n        if self.q_bias is not None:\n            q_bias = self.q_bias\n            
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)\n            v_bias = self.v_bias\n\n        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)\n        q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)    # (B, N_head, N_q, dim)\n\n        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)\n        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)\n\n        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)   \n        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))      # (B, N_head, N_q, N_k)\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1) \n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values > 0:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, bool_masked_pos=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x), bool_masked_pos))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), bool_masked_pos))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass RegressorBlock(nn.Module):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1_q = norm_layer(dim)\n        self.norm1_k = norm_layer(dim)\n        self.norm1_v = norm_layer(dim)\n        self.norm2_cross = norm_layer(dim)\n        self.cross_attn =  CrossAttention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        \n        self.mlp_cross = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values > 0:\n            self.gamma_1_cross = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2_cross = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1_cross = nn.Parameter(torch.ones((dim)),requires_grad=False)\n            self.gamma_2_cross = nn.Parameter(torch.ones((dim)),requires_grad=False)\n\n    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos):\n        x = x_q + self.drop_path(self.gamma_1_cross * self.cross_attn(self.norm1_q(x_q + pos_q),\n         bool_masked_pos, k=self.norm1_k(x_kv + pos_k), v=self.norm1_v(x_kv)))\n        x = self.norm2_cross(x)\n        x = x + self.drop_path(self.gamma_2_cross * self.mlp_cross(x))\n\n        return x\n\n\n'''\nEncoder that extracts representations\n'''\nclass VisionTransformerEncoder(nn.Module):\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,\n                 use_abs_pos_emb=True, init_std=0.02, args=None, **kwargs):\n        super().__init__()\n        self.num_features = self.embed_dim = embed_dim\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n        self.num_patches = num_patches\n\n        # generate class token and pos embed\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim, use_cls_token=True)\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=None,\n                attn_head_dim=attn_head_dim,\n            )\n            for i in range(depth)])\n        self.norm = norm_layer(embed_dim)\n\n        self.init_std = init_std\n\n        # init the model\n        trunc_normal_(self.cls_token, std=self.init_std)\n        self.apply(self._init_weights)\n        # rescale init function from beit\n        # if it is not activated, it will be overwritten\n        self.fix_init_weight()\n\n    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., use_cls_token=False):\n        h, w = self.patch_embed.patch_shape\n        grid_w = torch.arange(w, dtype=torch.float32)\n        grid_h = torch.arange(h, dtype=torch.float32)\n        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)\n        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'\n        pos_dim = embed_dim // 4\n        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n        omega = 1. 
/ (temperature ** omega)\n        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])\n        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])\n        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]\n\n        if not use_cls_token:\n            pos_embed = nn.Parameter(pos_emb)\n        else:\n            pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)\n            pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))\n        pos_embed.requires_grad = False\n        return pos_embed\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=self.init_std)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            trunc_normal_(m.weight, std=self.init_std)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    def forward_features(self, x, bool_masked_pos):\n        x = self.patch_embed(x, bool_masked_pos=bool_masked_pos)\n        batch_size, seq_len, dim = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)\n\n        # unmasked embeddings\n        x_unmasked = x[~bool_masked_pos].reshape(batch_size, -1, dim)\n        x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)\n\n        if self.pos_embed is not None:\n            pos_embed = self.pos_embed.expand(batch_size, self.num_patches+1, dim)\n            pos_embed_unmasked = pos_embed[:,1:][~bool_masked_pos].reshape(batch_size, -1, dim) \n            pos_embed_unmasked = torch.cat((pos_embed[:,:1], pos_embed_unmasked),dim=1)\n            x_unmasked = x_unmasked + pos_embed_unmasked\n\n        x_unmasked = self.pos_drop(x_unmasked)\n\n        for blk in self.blocks:\n            x_unmasked = blk(x_unmasked, bool_masked_pos)\n\n        x_unmasked = self.norm(x_unmasked)\n\n        return x_unmasked\n\n    def forward(self, x, bool_masked_pos, return_all_tokens=False):\n        x = self.forward_features(x, bool_masked_pos=bool_masked_pos)\n        return x\n\n'''\nLatent context regressor + decoder that solves the pretext task.\n'''\nclass VisionTransformerNeck(nn.Module):\n    def __init__(self, patch_size=16, num_classes=8192, embed_dim=768, depth=6, \n                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=None, init_values=None, num_patches=196, init_std=0.02, args=None, patch_shape=(14,14)):\n        super().__init__()\n\n        self.num_features = self.embed_dim = embed_dim\n        self.patch_size = patch_size\n        self.args = args\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]\n\n        # context regressor\n        self.regressor_blocks = nn.ModuleList([\n            
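# cross-attention blocks: mask-token queries attend to the encoded visible patches (and the mask tokens themselves)\n            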
RegressorBlock(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values)\n            for i in range(depth)])\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, args.decoder_depth)]\n        self.decoder_blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values)\n            for i in range(args.decoder_depth)])\n\n        self.norm = norm_layer(embed_dim)\n        self.norm2 = norm_layer(embed_dim)\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        \n        self.init_std = init_std\n\n        # init the model\n        trunc_normal_(self.head.weight, std=self.init_std)\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.regressor_blocks):\n            rescale(layer.cross_attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp_cross.fc2.weight.data, layer_id + 1)\n\n        for layer_id, layer in enumerate(self.decoder_blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=self.init_std)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            trunc_normal_(m.weight, std=self.init_std)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n        \n    def forward(self, x_masked, x_unmasked, pos_embed_masked, pos_embed_unmasked, bool_masked_pos):                \n        # latent contextual regressor\n        for blk in self.regressor_blocks:\n            x_masked = blk(x_masked, torch.cat([x_unmasked, x_masked], dim=1), pos_embed_masked, torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1), bool_masked_pos)\n        x_masked = self.norm(x_masked)\n        latent_pred = x_masked\n        \n        x_masked = x_masked + pos_embed_masked  # add pos embed, like encoder\n        for blk in self.decoder_blocks:\n            x_masked = blk(x_masked)\n        x_masked = self.norm2(x_masked)\n\n        logits = self.head(x_masked)\n        \n        return logits, latent_pred\n"
  },
  {
    "path": "models/modeling_discrete_vae.py",
    "content": "# --------------------------------------------------------\n# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)\n# Github source: https://github.com/microsoft/unilm/tree/master/beit\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# By Hangbo Bao\n# Based on OpenAI DALL-E and lucidrains' DALLE-pytorch code bases\n# https://github.com/openai/DALL-E\n# https://github.com/lucidrains/DALLE-pytorch\n# --------------------------------------------------------'\nfrom math import sqrt\nimport os\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n\ndef top_k(logits, thres = 0.5):\n    num_logits = logits.shape[-1]\n    k = max(int((1 - thres) * num_logits), 1)\n    val, ind = torch.topk(logits, k)\n    probs = torch.full_like(logits, float('-inf'))\n    probs.scatter_(1, ind, val)\n    return probs\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    return val if exists(val) else d\n\n\ndef eval_decorator(fn):\n    def inner(model, *args, **kwargs):\n        was_training = model.training\n        model.eval()\n        out = fn(model, *args, **kwargs)\n        model.train(was_training)\n        return out\n    return inner\n\n\nclass BasicVAE(nn.Module):\n\n    def get_codebook_indices(self, images):\n        raise NotImplementedError()\n\n    def decode(self, img_seq):\n        raise NotImplementedError()\n\n    def get_codebook_probs(self, img_seq):\n        raise NotImplementedError()\n\n    def get_image_tokens_size(self):\n        pass\n\n    def get_image_size(self):\n        pass\n\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, chan):\n        super().__init__()\n        self.net = nn.Sequential(\n            nn.Conv2d(chan, chan, 3, padding=1),\n            nn.ReLU(),\n            nn.Conv2d(chan, chan, 3, padding=1),\n            nn.ReLU(),\n            nn.Conv2d(chan, chan, 1)\n        )\n\n    def forward(self, x):\n        return self.net(x) + x\n\n\n\n\nclass DiscreteVAE(BasicVAE):\n    def __init__(\n        self,\n        image_size = 256,\n        num_tokens = 512,\n        codebook_dim = 512,\n        num_layers = 3,\n        num_resnet_blocks = 2,\n        hidden_dim = 64,\n        channels = 3,\n        smooth_l1_loss = False,\n        temperature = 0.9,\n        straight_through = False,\n        kl_div_loss_weight = 0.,\n    ):\n        super().__init__()\n        # assert log2(image_size).is_integer(), 'image size must be a power of 2'\n        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'\n        has_resblocks = num_resnet_blocks > 0\n\n        self.image_size = image_size\n        self.num_tokens = num_tokens\n        self.num_layers = num_layers\n        self.temperature = temperature\n        self.straight_through = straight_through\n        self.codebook = nn.Embedding(num_tokens, codebook_dim)\n\n        hdim = hidden_dim\n\n        enc_chans = [hidden_dim] * num_layers\n        dec_chans = list(reversed(enc_chans))\n\n        enc_chans = [channels, *enc_chans]\n\n        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]\n        dec_chans = [dec_init_chan, *dec_chans]\n\n        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))\n\n        enc_layers = []\n        dec_layers = []\n\n        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):\n            
enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))\n            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))\n\n        for _ in range(num_resnet_blocks):\n            dec_layers.insert(0, ResBlock(dec_chans[1]))\n            enc_layers.append(ResBlock(enc_chans[-1]))\n\n        if num_resnet_blocks > 0:\n            dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))\n\n        enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))\n        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))\n\n        self.encoder = nn.Sequential(*enc_layers)\n        self.decoder = nn.Sequential(*dec_layers)\n\n        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss\n        self.kl_div_loss_weight = kl_div_loss_weight\n\n    def get_image_size(self):\n        return self.image_size\n\n    def get_image_tokens_size(self):\n        return self.image_size // 8\n\n    @torch.no_grad()\n    @eval_decorator\n    def get_codebook_indices(self, images):\n        logits = self.forward(images, return_logits = True)\n        codebook_indices = logits.argmax(dim = 1).flatten(1)\n        return codebook_indices\n\n    @torch.no_grad()\n    @eval_decorator\n    def get_codebook_probs(self, images, temp):\n        logits = self.forward(images, return_logits = True)\n        return nn.Softmax(dim=1)(logits / temp)\n\n    def decode(\n        self,\n        img_seq\n    ):\n        image_embeds = self.codebook(img_seq)\n        b, n, d = image_embeds.shape\n        h = w = int(sqrt(n))\n\n        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)\n        images = self.decoder(image_embeds)\n        return images\n\n    def forward(\n        self,\n        img,\n        return_loss = False,\n        return_recons = False,\n        return_logits = False,\n        temp = None\n    ):\n        device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight\n        assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'\n\n        logits = self.encoder(img)\n\n        if return_logits:\n            return logits # return logits for getting hard image indices for DALL-E training\n\n        temp = default(temp, self.temperature)\n        soft_one_hot = F.gumbel_softmax(logits.float(), tau = temp, dim = 1, hard = self.straight_through)\n        sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight).type_as(logits)\n        out = self.decoder(sampled)\n\n        if not return_loss:\n            return out\n\n        # reconstruction loss\n\n        recon_loss = self.loss_fn(img, out)\n\n        # kl divergence\n\n        logits = rearrange(logits, 'b n h w -> b (h w) n')\n        _C = logits.size(-1)\n        avg_probs = F.softmax(logits.contiguous().view(-1, _C), dim=-1, dtype=torch.float32).mean(0)\n        diversity_loss = torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=-1).mean()\n\n        if not return_recons:\n            return recon_loss, diversity_loss\n\n        return recon_loss, diversity_loss, out\n\n\n\nfrom dall_e import load_model\n\n\nclass Dalle_VAE(BasicVAE):\n    def __init__(self, image_size):\n        super().__init__()\n        self.encoder = None\n        self.decoder = None\n        self.image_size = image_size\n\n    def load_model(self, model_dir, device):\n        
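# model_dir is expected to contain the DALL-E tokenizer weights encoder.pkl and decoder.pkl\n        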
self.encoder = load_model(os.path.join(model_dir, \"encoder.pkl\"), device)\n        self.decoder = load_model(os.path.join(model_dir, \"decoder.pkl\"), device)\n\n    def decode(self, img_seq):\n        bsz = img_seq.size()[0]\n        img_seq = img_seq.view(bsz, self.image_size // 8, self.image_size // 8)\n        z = F.one_hot(img_seq, num_classes=self.encoder.vocab_size).permute(0, 3, 1, 2).float()\n        return self.decoder(z).float()\n\n    def get_codebook_indices(self, images):\n        z_logits = self.encoder(images)\n        return torch.argmax(z_logits, axis=1)\n\n    def get_codebook_probs(self, images):\n        z_logits = self.encoder(images)\n        return nn.Softmax(dim=1)(z_logits)\n\n    def forward(self, img_seq_prob, no_process=False):\n        if no_process:\n            return self.decoder(img_seq_prob.float()).float()\n        else:\n            bsz, seq_len, num_class = img_seq_prob.size()\n            z = img_seq_prob.view(bsz, self.image_size // 8, self.image_size // 8, self.encoder.vocab_size)\n            return self.decoder(z.permute(0, 3, 1, 2).float()).float()\n\n\nclass VGGAN(BasicVAE):\n    def __init__(self, image_size):\n        super().__init__()\n        self.encoder = None\n        self.decoder = None\n        self.image_size = image_size\n\n    def load_model(self, weight_path, device):\n        self.vqgan = torch.load(weight_path, map_location=device)\n\n    def get_codebook_indices(self, images):\n        _, _, [_, _, indices] = self.vqgan.encode(images)   # indices: [b, h//8, w//8]\n        return indices\n"
  },
  {
    "path": "models/modeling_finetune.py",
    "content": "import math\nimport numpy as np\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom furnace.utils import LP_BatchNorm\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),\n        **kwargs\n    }\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = self.drop(x)\n        # commit this for the orignal BERT implement \n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n\n        if window_size:\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n            # get pair-wise relative position index for each token inside the window\n            coords_h = torch.arange(window_size[0])\n            coords_w = torch.arange(window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += 
window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\n            relative_position_index[0, 0] = self.num_relative_distance - 1\n\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, rel_pos_bias=None):\n        B, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.relative_position_bias_table is not None:\n            relative_position_bias = \\\n                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                    self.window_size[0] * self.window_size[1] + 1,\n                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if rel_pos_bias is not None:\n            attn = attn + rel_pos_bias\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\nclass CrossAttention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.q = nn.Linear(dim, all_head_dim, bias=False)\n        self.k = nn.Linear(dim, all_head_dim, bias=False)\n        self.v = nn.Linear(dim, all_head_dim, bias=False)\n\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.k_bias = None\n            self.v_bias = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, 
x, bool_masked_pos=None, k=None, v=None):\n        B, N, C = x.shape\n        N_k = k.shape[1]\n        N_v = v.shape[1]\n\n        q_bias, k_bias, v_bias = None, None, None\n        if self.q_bias is not None:\n            q_bias = self.q_bias\n            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)\n            v_bias = self.v_bias\n\n        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)\n        q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)    # (B, N_head, N_q, dim)\n\n        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)\n        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)\n\n        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)   \n        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))      # (B, N_head, N_q, N_k)\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1) \n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values > 0:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, rel_pos_bias=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\nclass AttentiveBlock(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n\n        self.norm1_q = norm_layer(dim)\n        self.norm1_k = norm_layer(dim)\n        self.norm1_v = norm_layer(dim)\n        self.norm2_cross = norm_layer(dim)\n        self.cross_attn =  CrossAttention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        \n    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):\n        x_q = self.norm1_q(x_q + pos_q)\n        x_k = self.norm1_k(x_kv + pos_k)\n        x_v = self.norm1_v(x_kv)\n\n        x = self.cross_attn(x_q, k=x_k, v=x_v)\n\n        return x\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x, **kwargs):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n\n\nclass RelativePositionBias(nn.Module):\n\n    def __init__(self, window_size, num_heads):\n        super().__init__()\n        self.window_size = window_size\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        # cls to token & token 2 cls & cls to cls\n\n        # get pair-wise relative position index 
for each token inside the window\n        coords_h = torch.arange(window_size[0])\n        coords_w = torch.arange(window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n        relative_position_index = \\\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        relative_position_index[0, 0:] = self.num_relative_distance - 3\n        relative_position_index[0:, 0] = self.num_relative_distance - 2\n        relative_position_index[0, 0] = self.num_relative_distance - 1\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        # trunc_normal_(self.relative_position_bias_table, std=.02)\n\n    def forward(self):\n        relative_position_bias = \\\n            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                self.window_size[0] * self.window_size[1] + 1,\n                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\ndef get_sinusoid_encoding_table(n_position, d_hid, token=False):\n    ''' Sinusoid position encoding table '''\n\n    def get_position_angle_vec(position):\n        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n\n    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i\n    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1\n\n    if token:\n        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], dim=0)\n\n    return torch.FloatTensor(sinusoid_table).unsqueeze(0)\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,\n                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,\n                 use_mean_pooling=True, init_scale=0.001, lin_probe=False, linear_type='standard', args=None):\n        super().__init__()\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.use_mean_pooling = use_mean_pooling\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.use_abs_pos_emb = 
use_abs_pos_emb\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        elif args.sin_pos_emb:\n            # sine-cosine positional embeddings is on the way\n            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)\n        else:\n            self.pos_embed = None\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        if use_shared_rel_pos_bias:\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\n        else:\n            self.rel_pos_bias = None\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.use_rel_pos_bias = use_rel_pos_bias\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)\n\n        self.lin_probe = lin_probe\n        self.linear_type = linear_type\n\n        if lin_probe:\n            if self.linear_type == 'standard':\n                self.fc_norm = None\n            elif self.linear_type == 'attentive':\n                self.query_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n                self.attentive_blocks = nn.ModuleList([\n                    AttentiveBlock(\n                        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0, norm_layer=norm_layer,\n                        init_values=0)\n                    for i in range(1)])\n                self.fc_norm = LP_BatchNorm(embed_dim, affine=False)\n        else:\n            if use_mean_pooling:\n                self.fc_norm = norm_layer(embed_dim)\n            else:\n                self.fc_norm = None\n\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        if self.pos_embed is not None and use_abs_pos_emb:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        trunc_normal_(self.head.weight, std=.02)\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n\n        self.head.weight.data.mul_(init_scale)\n        self.head.bias.data.mul_(init_scale)\n\n    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000.):\n        h, w = self.patch_embed.patch_shape\n        grid_w = torch.arange(w, dtype=torch.float32)\n        grid_h = torch.arange(h, dtype=torch.float32)\n        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)\n        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'\n        pos_dim = embed_dim // 4\n        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n        omega = 1. 
/ (temperature ** omega)\n        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])\n        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])\n        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]\n\n        # if self.use_mean_pooling:\n        #     pos_embed = nn.Parameter(pos_emb)\n        # else:\n        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)\n        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))\n        pos_embed.requires_grad = False\n        return pos_embed\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x, is_train=True):\n        x = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is not None:\n            if self.use_abs_pos_emb:\n                x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()\n            else:\n                x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()\n                \n        x = self.pos_drop(x)\n\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        for blk in self.blocks:\n            x = blk(x, rel_pos_bias=rel_pos_bias)\n\n        x = self.norm(x)\n\n        # linear probing or attentive probing\n        if self.lin_probe:\n            if self.linear_type == 'standard':\n                return x[:, 0]\n            else:\n                query_tokens = self.query_token.expand(batch_size, -1, -1)\n                for blk in self.attentive_blocks:\n                    query_tokens = blk(query_tokens, x, 0, 0, bool_masked_pos=None, rel_pos_bias=None)\n                return self.fc_norm(query_tokens[:, 0, :], is_train=is_train) \n        else:   # finetune\n            if self.fc_norm is not None:    # use mean pooling\n                t = x[:, 1:, :]\n                return self.fc_norm(t.mean(1))\n            else:\n                return x[:, 0]\n\n    def forward(self, x, is_train=True):\n        x = self.forward_features(x, is_train)\n        x = self.head(x)\n        return x\n\n\n@register_model\ndef cae_small_patch16_224(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        patch_size=16, embed_dim=384, depth=12, 
num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n@register_model\ndef cae_base_patch16_224(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef cae_base_patch16_384(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef cae_large_patch16_224(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef cae_large_patch16_384(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef cae_large_patch16_512(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==1.7.1\ntorchvision==0.8.2\ntimm==0.3.2\nPillow\nblobfile\nmypy\nnumpy\npytest\nrequests\neinops\ntensorboardX\ndeepspeed==0.4.0\nscipy\npytorch-lightning==1.0.8\nomegaconf==2.0.0\n"
  },
  {
    "path": "scripts/cae_base_800e.sh",
    "content": "tmp_my_name=${0##*/}\nmy_name=${tmp_my_name%.*}\n\nOUTPUT_DIR='./output/'$my_name\nDATA_PATH=/path/to/imagenet1k/train\nTOKENIZER_PATH=./tokenizer-weights\n\nADDRESS=ADDR_FOR_THIS_MACHINE                                                                                 \nNNODES=4     \nRANK=RANK_FOR_THIS_MACHINE                                                                                                                        \n\n# ============================ pretraining ============================\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n  --nproc_per_node=8 \\\n  --nnodes=$NNODES \\\n  --node_rank=$RANK \\\n  --master_addr=$ADDRESS \\\n  --master_port=8899 \\\n  tools/run_pretraining.py \\\n  --data_path ${DATA_PATH} \\\n  --output_dir ${OUTPUT_DIR} \\\n  --model cae_base_patch16_224_8k_vocab --discrete_vae_weight_path ${TOKENIZER_PATH} \\\n  --batch_size 64 --lr 1.5e-3 --warmup_epochs 20 --epochs 800 \\\n  --clip_grad 3.0 --layer_scale_init_value 0.1 \\\n  --imagenet_default_mean_and_std \\\n  --color_jitter 0 \\\n  --drop_path 0.1 \\\n  --sincos_pos_emb \\\n  --mask_generator block \\\n  --num_mask_patches 98 \\\n  --decoder_layer_scale_init_value 0.1 \\\n  --no_auto_resume \\\n  --save_ckpt_freq 100 \\\n  --exp_name $my_name \\\n  --regressor_depth 4 \\\n  --decoder_depth 4 \\\n  --align_loss_weight 2\n\n\n# ============================ linear probing ============================\nDATA_PATH=/path/to/imagenet1k/\nMODEL_PATH=/path/to/pretrained/model\n\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n    --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=8899 \\\n    tools/run_linear.py \\\n    --model cae_base_patch16_224 --data_path $DATA_PATH \\\n    --finetune $MODEL_PATH \\\n    --nb_classes 1000 \\\n    --batch_size 512 \\\n    --epochs 90 \\\n    --blr 0.1 \\\n    --weight_decay 0.0 \\\n    --dist_eval --data_path ${DATA_PATH} \\\n    --output_dir $OUTPUT_DIR \\\n    --log_dir $OUTPUT_DIR \\\n    --enable_linear_eval \\\n    --use_cls \\\n    --dist_eval \\\n    --save_freq 50 \\\n    --disable_rel_pos_bias \\\n    --linear_type standard \\\n    --exp_name $my_name\n\n# ============================ attentive probing ============================\nDATA_PATH=/path/to/imagenet1k/\nMODEL_PATH=/path/to/pretrained/model\n\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n    --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=8899 \\\n    tools/run_attentive.py \\\n    --model cae_base_patch16_224 --data_path $DATA_PATH \\\n    --finetune $MODEL_PATH \\\n    --nb_classes 1000 --data_set IMNET --imagenet_default_mean_and_std \\\n    --output_dir $OUTPUT_DIR --batch_size 256 --lr 0.4 --update_freq 1 \\\n    --warmup_epochs 10 --epochs 90 \\\n    --weight_decay 0 --smoothing 0.0 --layer_decay 1.0 --drop_path 0.0 \\\n    --color_jitter 0.0 --mixup 0.0 --cutmix 0.0 --reprob 0.0 \\\n    --opt sgd --momentum 0.9 \\\n    --enable_linear_eval \\\n    --use_cls \\\n    --dist_eval \\\n    --no_auto_resume \\\n    --save_ckpt_freq 50 \\\n    --linear_type attentive \\\n    --exp_name $my_name\n"
  },
  {
    "path": "scripts/cae_base_finetune.sh",
    "content": "tmp_my_name=${0##*/}\nmy_name=${tmp_my_name%.*}\n\nOUTPUT_DIR='./output/'$my_name\nDATA_PATH=/path/to/imagenet1k/train\nTOKENIZER_PATH=./tokenizer-weights\n\nADDRESS=ADDR_FOR_THIS_MACHINE                                                                                 \nNNODES=4     \nRANK=RANK_FOR_THIS_MACHINE                                                                                                                        \n\nMODEL_PATH=/path/to/pretrained/model\n\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n    --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=8899 \\\n    tools/run_class_finetuning.py \\\n    --model cae_base_patch16_224  --data_path $DATA_PATH \\\n    --finetune $MODEL_PATH \\\n    --nb_classes 1000 --data_set IMNET \\\n    --output_dir $OUTPUT_DIR \\\n    --batch_size 128 \\\n    --lr 8e-3 --update_freq 1 \\\n    --warmup_epochs 5 --epochs 100 --layer_decay 0.65 --drop_path 0.1 \\\n    --weight_decay 0.05 --mixup 0.8 --cutmix 1.0 \\\n\t--sin_pos_emb \\\n    --dist_eval \\\n    --no_auto_resume \\\n    --exp_name $my_name \\\n    --imagenet_default_mean_and_std\n\n\n\n\n"
  },
  {
    "path": "scripts/cae_large_1600e.sh",
    "content": "tmp_my_name=${0##*/}\nmy_name=${tmp_my_name%.*}\n\nOUTPUT_DIR='./output/'$my_name\nDATA_PATH=/path/to/imagenet1k/train\nTOKENIZER_PATH=./tokenizer-weights\n\nADDRESS=ADDR_FOR_THIS_MACHINE                                                                                 \nNNODES=4     \nRANK=RANK_FOR_THIS_MACHINE                                                                                                                        \n\n\n# ============================ pretraining ============================\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n  --nproc_per_node=8 \\\n  --nnodes=$NNODES \\\n  --node_rank=$RANK \\\n  --master_addr=$ADDRESS \\\n  --master_port=8899 \\\n  tools/run_pretraining.py \\\n  --data_path ${DATA_PATH} \\\n  --output_dir ${OUTPUT_DIR} \\\n  --model cae_large_patch16_224_8k_vocab --discrete_vae_weight_path ${TOKENIZER_PATH} \\\n  --batch_size 64 --lr 1.5e-3 --warmup_epochs 40 --epochs 1600 \\\n  --clip_grad 3.0 --layer_scale_init_value 1e-5 \\\n  --imagenet_default_mean_and_std \\\n  --color_jitter 0 \\\n  --drop_path 0.1 \\\n  --sincos_pos_emb \\\n  --mask_generator block \\\n  --num_mask_patches 98 \\\n  --decoder_layer_scale_init_value 1e-5 \\\n  --no_auto_resume \\\n  --save_ckpt_freq 100 \\\n  --exp_name $my_name \\\n  --regressor_depth 4 \\\n  --decoder_depth 4 \\\n  --align_loss_weight 2\n  --decoder_embed_dim 1024 \\\n  --decoder_num_heads 16 \\\n  --fix_init_weight\n\n# ============================ linear probing ============================\nDATA_PATH=/path/to/imagenet1k/\nMODEL_PATH=/path/to/pretrained/model\n\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n    --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=8899 \\\n    tools/run_linear.py \\\n    --model cae_large_patch16_224 --data_path $DATA_PATH \\\n    --finetune $MODEL_PATH \\\n    --nb_classes 1000 \\\n    --batch_size 512 \\\n    --epochs 90 \\\n    --blr 0.1 \\\n    --weight_decay 0.0 \\\n    --dist_eval --data_path ${DATA_PATH} \\\n    --output_dir $OUTPUT_DIR \\\n    --log_dir $OUTPUT_DIR \\\n    --enable_linear_eval \\\n    --use_cls \\\n    --dist_eval \\\n    --save_freq 90 \\\n    --disable_rel_pos_bias \\\n    --linear_type standard \\\n    --exp_name $my_name\n\n\n# ============================ attentive probing ============================\nDATA_PATH=/path/to/imagenet1k/\nMODEL_PATH=/path/to/pretrained/model\n\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n    --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=8899 \\\n    tools/run_attentive.py \\\n    --model cae_large_patch16_224 --data_path $DATA_PATH \\\n    --finetune $MODEL_PATH \\\n    --nb_classes 1000 --data_set IMNET --imagenet_default_mean_and_std \\\n    --output_dir $OUTPUT_DIR --batch_size 256 --lr 0.4 --update_freq 1 \\\n    --warmup_epochs 10 --epochs 90 \\\n    --weight_decay 0 --smoothing 0.0 --layer_decay 1.0 --drop_path 0.0 \\\n    --color_jitter 0.0 --mixup 0.0 --cutmix 0.0 --reprob 0.0 \\\n    --opt sgd --momentum 0.9 \\\n    --enable_linear_eval \\\n    --use_cls \\\n    --dist_eval \\\n    --no_auto_resume \\\n    --save_ckpt_freq 50 \\\n    --linear_type attentive \\\n    --exp_name $my_name\n"
  },
  {
    "path": "scripts/cae_large_finetune.sh",
    "content": "tmp_my_name=${0##*/}\nmy_name=${tmp_my_name%.*}\n\nOUTPUT_DIR='./output/'$my_name\nDATA_PATH=/path/to/imagenet1k/train\nTOKENIZER_PATH=./tokenizer-weights\n\nADDRESS=ADDR_FOR_THIS_MACHINE                                                                                 \nNNODES=4     \nRANK=RANK_FOR_THIS_MACHINE                                                                                                                        \n\nMODEL_PATH=/path/to/pretrained/model\n\n\nOMP_NUM_THREADS=1 python -m torch.distributed.launch \\\n    --nproc_per_node=8 \\\n    --nnodes=$NNODES \\\n    --node_rank=$RANK \\\n    --master_addr=$ADDRESS \\\n    --master_port=8899 \\\n    tools/run_class_finetuning.py \\\n    --model cae_large_patch16_224  --data_path $DATA_PATH \\\n    --finetune $MODEL_PATH \\\n    --nb_classes 1000 --data_set IMNET \\\n    --output_dir $OUTPUT_DIR \\\n    --batch_size 64 \\\n    --lr 2e-3 --update_freq 2 \\\n    --warmup_epochs 5 --epochs 50 --layer_decay 0.75 --drop_path 0.2 \\\n    --weight_decay 0.05 --mixup 0.8 --cutmix 1.0 \\\n\t--sin_pos_emb \\\n    --dist_eval \\\n    --no_auto_resume \\\n    --exp_name $my_name \\\n    --imagenet_default_mean_and_std\n\n\n\n\n"
  },
  {
    "path": "tokenizer-weights/README",
    "content": "-- tokenizers"
  },
  {
    "path": "tools/run_attentive.py",
    "content": "import argparse\nimport datetime\nimport numpy as np\nimport time\nimport torch\nimport torch.backends.cudnn as cudnn\nimport json\nimport os\n\nfrom pathlib import Path\n\nfrom timm.data.mixup import Mixup\nfrom timm.models import create_model\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy\nfrom timm.utils import ModelEma\nfrom furnace.optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner\n\nfrom furnace.datasets import build_dataset\nfrom furnace.engine_for_finetuning import train_one_epoch, evaluate\nfrom furnace.utils import NativeScalerWithGradNormCount as NativeScaler\nimport furnace.utils as utils\nfrom scipy import interpolate\nimport models.modeling_finetune\n\n\ndef get_args():\n    parser = argparse.ArgumentParser('fine-tuning and evaluation script for image classification', add_help=False)\n    parser.add_argument('--batch_size', default=64, type=int)\n    parser.add_argument('--epochs', default=30, type=int)\n    parser.add_argument('--update_freq', default=1, type=int)\n    parser.add_argument('--save_ckpt_freq', default=5, type=int)\n\n    # Model parameters\n    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',\n                        help='Name of model to train')\n    parser.add_argument('--rel_pos_bias', action='store_true')\n    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')\n    parser.set_defaults(rel_pos_bias=True)\n    parser.add_argument('--abs_pos_emb', action='store_true')\n    parser.set_defaults(abs_pos_emb=False)\n    parser.add_argument('--sin_pos_emb', action='store_true')\n    parser.set_defaults(sin_pos_emb=True)\n    parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb')\n\n    parser.add_argument('--layer_scale_init_value', default=0.1, type=float, \n                        help=\"0.1 for base, 1e-5 for large. 
set 0 to disable layer scale\")\n\n    parser.add_argument('--input_size', default=224, type=int,\n                        help='images input size')\n\n    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                        help='Dropout rate (default: 0.)')\n    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',\n                        help='Attention dropout rate (default: 0.)')\n    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',\n                        help='Drop path rate (default: 0.1)')\n\n    parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)\n\n    parser.add_argument('--model_ema', action='store_true', default=False)\n    parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')\n    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')\n\n    # Optimizer parameters\n    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                        help='Optimizer (default: \"adamw\"')\n    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',\n                        help='Optimizer Epsilon (default: 1e-8)')\n    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',\n                        help='Optimizer Betas (default: None, use opt default)')\n    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',\n                        help='Clip gradient norm (default: None, no clipping)')\n    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                        help='SGD momentum (default: 0.9)')\n    parser.add_argument('--weight_decay', type=float, default=0.05,\n                        help='weight decay (default: 0.05)')\n    parser.add_argument('--weight_decay_end', type=float, default=None, help=\"\"\"Final value of the\n        weight decay. We use a cosine schedule for WD and using a larger decay by\n        the end of training improves performance for ViTs.\"\"\")\n\n    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',\n                        help='learning rate (default: 5e-4)')\n    parser.add_argument('--layer_decay', type=float, default=0.9)\n\n    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',\n                        help='warmup learning rate (default: 1e-6)')\n    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\n\n    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',\n                        help='epochs to warmup LR, if scheduler supports')\n    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',\n                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')\n\n    # Augmentation parameters\n    parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n    parser.add_argument('--aa', type=str, default='', metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". 
\" + \"(default: rand-m9-mstd0.5-inc1)'),\n    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--train_interpolation', type=str, default='bicubic',\n                        help='Training interpolation (random, bilinear, bicubic default: \"bicubic\")')\n\n    # Evaluation parameters\n    parser.add_argument('--crop_pct', type=float, default=None)\n\n    # * Random Erase params\n    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                        help='Random erase prob (default: 0.25)')\n    parser.add_argument('--remode', type=str, default='pixel',\n                        help='Random erase mode (default: \"pixel\")')\n    parser.add_argument('--recount', type=int, default=1,\n                        help='Random erase count (default: 1)')\n    parser.add_argument('--resplit', action='store_true', default=False,\n                        help='Do not random erase first (clean) augmentation split')\n\n    # * Mixup params\n    parser.add_argument('--mixup', type=float, default=0,\n                        help='mixup alpha, mixup enabled if > 0.')\n    parser.add_argument('--cutmix', type=float, default=0,\n                        help='cutmix alpha, cutmix enabled if > 0.')\n    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,\n                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\n    parser.add_argument('--mixup_prob', type=float, default=1.0,\n                        help='Probability of performing mixup or cutmix when either/both is enabled')\n    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,\n                        help='Probability of switching to cutmix when both mixup and cutmix enabled')\n    parser.add_argument('--mixup_mode', type=str, default='batch',\n                        help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\n\n    # * Finetuning params\n    parser.add_argument('--finetune', default='',\n                        help='finetune from checkpoint')\n    parser.add_argument('--model_key', default='model|module|state_dict', type=str)\n    parser.add_argument('--model_prefix', default='', type=str)\n    parser.add_argument('--init_scale', default=0.001, type=float)\n    parser.add_argument('--use_mean_pooling', action='store_true')\n    parser.set_defaults(use_mean_pooling=True)\n    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')\n    parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)\n\n    # Dataset parameters\n    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,\n                        help='dataset path')\n    parser.add_argument('--eval_data_path', default=None, type=str,\n                        help='dataset path for evaluation')\n    parser.add_argument('--nb_classes', default=0, type=int,\n                        help='number of the classification types')\n    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')\n\n    parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET100', 'image_folder'],\n                        type=str, help='ImageNet dataset path')\n    parser.add_argument('--output_dir', default='',\n                        help='path where to save, empty for no saving')\n    parser.add_argument('--log_dir', default=None,\n                        help='path where to tensorboard log')\n    parser.add_argument('--device', default='cuda',\n                        help='device to use for training / testing')\n    parser.add_argument('--seed', default=0, type=int)\n    parser.add_argument('--resume', default='',\n                        help='resume from checkpoint')\n    parser.add_argument('--auto_resume', action='store_true')\n    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')\n    parser.set_defaults(auto_resume=True)\n\n    parser.add_argument('--save_ckpt', action='store_true')\n    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')\n    parser.set_defaults(save_ckpt=True)\n\n    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('--eval', action='store_true',\n                        help='Perform evaluation only')\n    parser.add_argument('--dist_eval', action='store_true', default=False,\n                        help='Enabling distributed evaluation')\n    parser.add_argument('--num_workers', default=10, type=int)\n    parser.add_argument('--pin_mem', action='store_true',\n                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\n    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')\n    parser.set_defaults(pin_mem=True)\n\n    # distributed training parameters\n    parser.add_argument('--world_size', default=1, type=int,\n                        help='number of distributed processes')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--dist_on_itp', action='store_true')\n    parser.add_argument('--dist_url', default='env://',\n                        help='url used to set up distributed training')\n\n    parser.add_argument('--enable_deepspeed', action='store_true', default=False)\n    
parser.add_argument('--enable_linear_eval', action='store_true', default=False)\n    parser.add_argument('--enable_multi_print', action='store_true',default=False, help='allow each gpu prints something')\n\n    parser.add_argument('--linear_type', default='standard', type=str, help='standard, attentive')\n    parser.add_argument('--exp_name', default='', type=str,\n                    help='name of exp. it is helpful when save the checkpoint')\n\n    known_args, _ = parser.parse_known_args()\n\n    if known_args.enable_deepspeed:\n        try:\n            import deepspeed\n            from deepspeed import DeepSpeedConfig\n            parser = deepspeed.add_config_arguments(parser)\n            ds_init = deepspeed.initialize\n        except:\n            print(\"Please 'pip install deepspeed==0.4.0'\")\n            exit(0)\n    else:\n        ds_init = None\n\n    return parser.parse_args(), ds_init\n\n\ndef main(args, ds_init):\n\n    if not args.enable_linear_eval:\n        args.aa = 'rand-m9-mstd0.5-inc1'\n\n    utils.init_distributed_mode(args)\n\n    if ds_init is not None:\n        utils.create_ds_config(args)\n\n    print(args)\n\n    device = torch.device(args.device)\n\n    # fix the seed for reproducibility\n    seed = args.seed + utils.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    # random.seed(seed)\n\n    cudnn.benchmark = True\n\n    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)\n    if args.disable_eval_during_finetuning:\n        dataset_val = None\n    else:\n        dataset_val, _ = build_dataset(is_train=False, args=args)\n\n    if True:  # args.distributed:\n        num_tasks = utils.get_world_size()\n        global_rank = utils.get_rank()\n        sampler_train = torch.utils.data.DistributedSampler(\n            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n        )\n        print(\"Sampler_train = %s\" % str(sampler_train))\n        if args.dist_eval:\n            if len(dataset_val) % num_tasks != 0:\n                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
'\n                      'This will slightly alter validation results as extra duplicate entries are added to achieve '\n                      'equal num of samples per-process.')\n            sampler_val = torch.utils.data.DistributedSampler(\n                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)\n        else:\n            sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n    else:\n        sampler_train = torch.utils.data.RandomSampler(dataset_train)\n        sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n\n    if global_rank == 0 and args.log_dir is not None:\n        os.makedirs(args.log_dir, exist_ok=True)\n        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)\n    else:\n        log_writer = None\n\n    data_loader_train = torch.utils.data.DataLoader(\n        dataset_train, sampler=sampler_train,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=True,\n    )\n\n    if dataset_val is not None:\n        data_loader_val = torch.utils.data.DataLoader(\n            dataset_val, sampler=sampler_val,\n            batch_size=int(1.5 * args.batch_size),\n            num_workers=args.num_workers,\n            pin_memory=args.pin_mem,\n            drop_last=False\n        )\n    else:\n        data_loader_val = None\n\n    mixup_fn = None\n    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None\n    if mixup_active:\n        print(\"Mixup is activated!\")\n        mixup_fn = Mixup(\n            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,\n            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,\n            label_smoothing=args.smoothing, num_classes=args.nb_classes)\n\n    model = create_model(\n        args.model,\n        pretrained=False,\n        num_classes=args.nb_classes,\n        drop_rate=args.drop,\n        drop_path_rate=args.drop_path,\n        attn_drop_rate=args.attn_drop_rate,\n        drop_block_rate=None,\n        use_mean_pooling=args.use_mean_pooling,\n        init_scale=args.init_scale,\n        use_rel_pos_bias=args.rel_pos_bias,\n        use_abs_pos_emb=args.abs_pos_emb,\n        init_values=args.layer_scale_init_value,\n        lin_probe=args.enable_linear_eval,\n        linear_type=args.linear_type,\n        args=args,\n    )\n\n    if args.enable_linear_eval:\n        linear_keyword = 'head'\n        head_norm = 'fc_norm'\n        parameters_requires_grad = []\n        for name, param in model.named_parameters():\n            param.requires_grad = False # no grad by default\n            if 'gamma' in name:\n                param.requires_grad = False\n            else:\n                if ('query_token' in name) or ('attentive_blocks' in name) or (name in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]) or (head_norm in name):\n                    parameters_requires_grad.append(name)\n                    param.requires_grad = True\n\n        print(f'parameters that need grad: ', parameters_requires_grad)\n        getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)\n        getattr(model, linear_keyword).bias.data.zero_()\n\n    patch_size = model.patch_embed.patch_size\n    print(\"Patch size = %s\" % str(patch_size))\n    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])\n    args.patch_size = patch_size\n\n    if args.finetune:\n        if 
args.finetune.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.finetune, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.finetune, map_location='cpu')\n\n        print(\"Load ckpt from %s\" % args.finetune)\n        checkpoint_model = None\n        for model_key in args.model_key.split('|'):\n            if model_key in checkpoint:\n                checkpoint_model = checkpoint[model_key]\n                print(\"Load state_dict by model_key = %s\" % model_key)\n                break\n        if checkpoint_model is None:\n            checkpoint_model = checkpoint\n        state_dict = model.state_dict()\n        original_all_keys = list(checkpoint_model.keys())\n        print(\"##########origin keys:\", len(original_all_keys), original_all_keys)\n        # NOTE: remove all decoder keys\n        all_keys = [key for key in original_all_keys if key.startswith('encoder.')]\n        print(\"all keys:\", all_keys)\n        for key in all_keys:\n            new_key = key.replace('encoder.','')\n            # print(\"new_key:\", new_key)\n            checkpoint_model[new_key] = checkpoint_model[key]\n            checkpoint_model.pop(key)\n\n\n        # handle moco-v3 checkpoints\n        all_keys = [key for key in original_all_keys if key.startswith('module.base_encoder.')]\n        print(\"all keys:\", all_keys)\n        for key in all_keys:\n            new_key = key.replace('module.base_encoder.','')\n            # print(\"new_key:\", new_key)\n            checkpoint_model[new_key] = checkpoint_model[key]\n            checkpoint_model.pop(key)\n            \n        for key in list(checkpoint_model.keys()):\n            if key.startswith('decoder.'):\n                # print(\"key:\", key)\n                checkpoint_model.pop(key)\n            if key.startswith('teacher.'):\n                # print(\"key:\", key)\n                checkpoint_model.pop(key)\n\n        for k in ['head.weight', 'head.bias']:\n            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:\n                print(f\"Removing key {k} from pretrained checkpoint\")\n                del checkpoint_model[k]\n                \n        if model.use_rel_pos_bias and \"rel_pos_bias.relative_position_bias_table\" in checkpoint_model:\n            print(\"Expand the shared relative position embedding to each transformer block. 
\")\n            num_layers = model.get_num_layers()\n            rel_pos_bias = checkpoint_model[\"rel_pos_bias.relative_position_bias_table\"]\n            for i in range(num_layers):\n                checkpoint_model[\"blocks.%d.attn.relative_position_bias_table\" % i] = rel_pos_bias.clone()\n\n            checkpoint_model.pop(\"rel_pos_bias.relative_position_bias_table\")\n\n        all_keys = list(checkpoint_model.keys())\n\n        for key in all_keys:\n            if \"relative_position_index\" in key:\n                checkpoint_model.pop(key)\n\n            if \"relative_position_bias_table\" in key and args.rel_pos_bias:\n                rel_pos_bias = checkpoint_model[key]\n                src_num_pos, num_attn_heads = rel_pos_bias.size()\n                dst_num_pos, _ = model.state_dict()[key].size()\n                dst_patch_shape = model.patch_embed.patch_shape\n                if dst_patch_shape[0] != dst_patch_shape[1]:\n                    raise NotImplementedError()\n                num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)\n                src_size = int((src_num_pos - num_extra_tokens) ** 0.5)\n                dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)\n                if src_size != dst_size:\n                    print(\"Position interpolate for %s from %dx%d to %dx%d\" % (\n                        key, src_size, src_size, dst_size, dst_size))\n                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\n                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\n\n                    def geometric_progression(a, r, n):\n                        return a * (1.0 - r ** n) / (1.0 - r)\n\n                    left, right = 1.01, 1.5\n                    while right - left > 1e-6:\n                        q = (left + right) / 2.0\n                        gp = geometric_progression(1, q, src_size // 2)\n                        if gp > dst_size // 2:\n                            right = q\n                        else:\n                            left = q\n\n                    # if q > 1.090307:\n                    #     q = 1.090307\n\n                    dis = []\n                    cur = 1\n                    for i in range(src_size // 2):\n                        dis.append(cur)\n                        cur += q ** (i + 1)\n\n                    r_ids = [-_ for _ in reversed(dis)]\n\n                    x = r_ids + [0] + dis\n                    y = r_ids + [0] + dis\n\n                    t = dst_size // 2.0\n                    dx = np.arange(-t, t + 0.1, 1.0)\n                    dy = np.arange(-t, t + 0.1, 1.0)\n\n                    print(\"Original positions = %s\" % str(x))\n                    print(\"Target positions = %s\" % str(dx))\n\n                    all_rel_pos_bias = []\n\n                    for i in range(num_attn_heads):\n                        z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()\n                        f = interpolate.interp2d(x, y, z, kind='cubic')\n                        all_rel_pos_bias.append(\n                            torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))\n\n                    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n\n                    new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)\n                    checkpoint_model[key] = new_rel_pos_bias\n\n        print(\"##############new keys:\", len(checkpoint_model), checkpoint_model.keys())\n        
#print(\"##############model:\", model)\n\n        # interpolate position embedding\n        if 'pos_embed' in checkpoint_model and args.abs_pos_emb:\n            pos_embed_checkpoint = checkpoint_model['pos_embed']\n            embedding_size = pos_embed_checkpoint.shape[-1]\n            num_patches = model.patch_embed.num_patches\n            num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n            # height (== width) for the checkpoint position embedding\n            orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n            # height (== width) for the new position embedding\n            new_size = int(num_patches ** 0.5)\n            # class_token and dist_token are kept unchanged\n            if orig_size != new_size:\n                print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n                # only the position tokens are interpolated\n                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n                pos_tokens = torch.nn.functional.interpolate(\n                    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n                checkpoint_model['pos_embed'] = new_pos_embed\n\n        utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)\n        # model.load_state_dict(checkpoint_model, strict=False)\n\n    model.to(device)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume='')\n        print(\"Using EMA with decay = %.8f\" % args.model_ema_decay)\n\n    model_without_ddp = model\n    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n    print(\"Model = %s\" % str(model_without_ddp))\n    print('number of params:', n_parameters)\n\n    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()\n    num_training_steps_per_epoch = len(dataset_train) // total_batch_size\n    \n    print(\"LR = %.8f\" % args.lr)\n    print(\"Batch size = %d\" % total_batch_size)\n    print(\"Update frequent = %d\" % args.update_freq)\n    print(\"Number of training examples = %d\" % len(dataset_train))\n    print(\"Number of training training per epoch = %d\" % num_training_steps_per_epoch)\n\n    num_layers = model_without_ddp.get_num_layers()\n    if args.layer_decay < 1.0:\n        assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))\n    else:\n        assigner = None\n\n    if assigner is not None:\n        print(\"Assigned values = %s\" % str(assigner.values))\n\n    skip_weight_decay_list = model.no_weight_decay()\n    print(\"Skip weight decay list: \", skip_weight_decay_list)\n\n    if args.disable_weight_decay_on_rel_pos_bias:\n        for i in range(num_layers):\n            skip_weight_decay_list.add(\"blocks.%d.attn.relative_position_bias_table\" % i)\n\n    if args.enable_deepspeed:\n        loss_scaler = 
None\n        optimizer_params = get_parameter_groups(\n            model, args.weight_decay, skip_weight_decay_list,\n            assigner.get_layer_id if assigner is not None else None,\n            assigner.get_scale if assigner is not None else None)\n        model, optimizer, _, _ = ds_init(\n            args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,\n        )\n\n        print(\"model.gradient_accumulation_steps() = %d\" % model.gradient_accumulation_steps())\n        assert model.gradient_accumulation_steps() == args.update_freq\n    else:\n        if args.distributed:\n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)\n            model_without_ddp = model.module\n\n        optimizer = create_optimizer(\n            args, model_without_ddp, skip_list=skip_weight_decay_list,\n            get_num_layer=assigner.get_layer_id if assigner is not None else None, \n            get_layer_scale=assigner.get_scale if assigner is not None else None)\n        loss_scaler = NativeScaler()\n\n    print(\"Use step level LR scheduler!\")\n    lr_schedule_values = utils.cosine_scheduler(\n        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,\n        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,\n    )\n    if args.weight_decay_end is None:\n        args.weight_decay_end = args.weight_decay\n    wd_schedule_values = utils.cosine_scheduler(\n        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)\n    print(\"Max WD = %.7f, Min WD = %.7f\" % (max(wd_schedule_values), min(wd_schedule_values)))\n\n    if mixup_fn is not None:\n        # smoothing is handled with mixup label transform\n        criterion = SoftTargetCrossEntropy()\n    elif args.smoothing > 0.:\n        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)\n    else:\n        criterion = torch.nn.CrossEntropyLoss()\n\n    print(\"criterion = %s\" % str(criterion))\n\n    utils.auto_load_model(\n        args=args, model=model, model_without_ddp=model_without_ddp,\n        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)\n\n    if args.eval:\n        test_stats = evaluate(data_loader_val, model, device)\n        print(f\"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%\")\n        exit(0)\n\n    print(f\"Start training for {args.epochs} epochs\")\n    start_time = time.time()\n    max_accuracy = 0.0\n    for epoch in range(args.start_epoch, args.epochs):\n        if args.distributed:\n            data_loader_train.sampler.set_epoch(epoch)\n        if log_writer is not None:\n            log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)\n\n        train_stats = train_one_epoch(\n            model, criterion, data_loader_train, optimizer,\n            device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,\n            log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,\n            lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,\n            num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,\n        )\n        if args.output_dir and args.save_ckpt:\n            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:\n                utils.save_model(\n                    args=args, model=model, model_without_ddp=model_without_ddp, 
optimizer=optimizer,\n                    loss_scaler=loss_scaler, epoch=epoch, exp_name=args.exp_name, model_ema=model_ema)\n\n        if data_loader_val is not None:\n            test_stats = evaluate(data_loader_val, model, device)\n            print(f\"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%\")\n            if max_accuracy < test_stats[\"acc1\"]:\n                max_accuracy = test_stats[\"acc1\"]\n                if args.output_dir and args.save_ckpt:\n                    utils.save_model(\n                        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,\n                        loss_scaler=loss_scaler, epoch=\"best\", model_ema=model_ema)\n\n            print(f'Max accuracy: {max_accuracy:.2f}%')\n            if log_writer is not None:\n                log_writer.update(test_acc1=test_stats['acc1'], head=\"perf\", step=epoch)\n                log_writer.update(test_acc5=test_stats['acc5'], head=\"perf\", step=epoch)\n                log_writer.update(test_loss=test_stats['loss'], head=\"perf\", step=epoch)\n\n            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                         **{f'test_{k}': v for k, v in test_stats.items()},\n                         'epoch': epoch,\n                         'n_parameters': n_parameters}\n        else:\n            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                         # **{f'test_{k}': v for k, v in test_stats.items()},\n                         'epoch': epoch,\n                         'n_parameters': n_parameters}\n\n        if args.output_dir and utils.is_main_process():\n            if log_writer is not None:\n                log_writer.flush()\n            with open(os.path.join(args.output_dir, \"log.txt\"), mode=\"a\", encoding=\"utf-8\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print('Training time {}'.format(total_time_str))\n\n\nif __name__ == '__main__':\n    opts, ds_init = get_args()\n    if opts.output_dir:\n        Path(opts.output_dir).mkdir(parents=True, exist_ok=True)\n    main(opts, ds_init)\n"
  },
  {
    "path": "tools/run_class_finetuning.py",
    "content": "import argparse\nimport datetime\nimport numpy as np\nimport time\nimport torch\nimport torch.backends.cudnn as cudnn\nimport json\nimport os\nimport shutil\n\nfrom pathlib import Path\n\nfrom timm.data.mixup import Mixup\nfrom timm.models import create_model\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy\nfrom timm.utils import ModelEma\nfrom furnace.optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner\n\nfrom furnace.datasets import build_dataset\nfrom furnace.engine_for_finetuning import train_one_epoch, evaluate\nfrom furnace.utils import NativeScalerWithGradNormCount as NativeScaler\nimport furnace.utils as utils\nfrom scipy import interpolate\nimport models.modeling_finetune\n\n\ndef get_args():\n    parser = argparse.ArgumentParser('fine-tuning and evaluation script for image classification', add_help=False)\n    parser.add_argument('--batch_size', default=64, type=int)\n    parser.add_argument('--epochs', default=30, type=int)\n    parser.add_argument('--update_freq', default=1, type=int)\n    parser.add_argument('--save_ckpt_freq', default=5, type=int)\n\n    # Model parameters\n    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',\n                        help='Name of model to train')\n    parser.add_argument('--rel_pos_bias', action='store_true')\n    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')\n    parser.set_defaults(rel_pos_bias=True)\n    parser.add_argument('--abs_pos_emb', action='store_true')\n    parser.set_defaults(abs_pos_emb=False)\n    parser.add_argument('--sin_pos_emb', action='store_true')\n    parser.set_defaults(sin_pos_emb=True)\n    parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb')\n\n    parser.add_argument('--layer_scale_init_value', default=0.1, type=float, \n                        help=\"0.1 for base, 1e-5 for large. 
set 0 to disable layer scale\")\n\n    parser.add_argument('--input_size', default=224, type=int,\n                        help='images input size')\n\n    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                        help='Dropout rate (default: 0.)')\n    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',\n                        help='Attention dropout rate (default: 0.)')\n    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',\n                        help='Drop path rate (default: 0.1)')\n\n    parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)\n\n    parser.add_argument('--model_ema', action='store_true', default=False)\n    parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')\n    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')\n\n    # Optimizer parameters\n    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                        help='Optimizer (default: \"adamw\"')\n    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',\n                        help='Optimizer Epsilon (default: 1e-8)')\n    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',\n                        help='Optimizer Betas (default: None, use opt default)')\n    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',\n                        help='Clip gradient norm (default: None, no clipping)')\n    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                        help='SGD momentum (default: 0.9)')\n    parser.add_argument('--weight_decay', type=float, default=0.05,\n                        help='weight decay (default: 0.05)')\n    parser.add_argument('--weight_decay_end', type=float, default=None, help=\"\"\"Final value of the\n        weight decay. We use a cosine schedule for WD and using a larger decay by\n        the end of training improves performance for ViTs.\"\"\")\n\n    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',\n                        help='learning rate (default: 5e-4)')\n    parser.add_argument('--layer_decay', type=float, default=0.9)\n\n    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',\n                        help='warmup learning rate (default: 1e-6)')\n    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\n\n    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',\n                        help='epochs to warmup LR, if scheduler supports')\n    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',\n                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')\n\n    # Augmentation parameters\n    parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n    parser.add_argument('--aa', type=str, default='', metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". 
\" + \"(default: rand-m9-mstd0.5-inc1)'),\n    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--train_interpolation', type=str, default='bicubic',\n                        help='Training interpolation (random, bilinear, bicubic default: \"bicubic\")')\n\n    # Evaluation parameters\n    parser.add_argument('--crop_pct', type=float, default=None)\n\n    # * Random Erase params\n    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                        help='Random erase prob (default: 0.25)')\n    parser.add_argument('--remode', type=str, default='pixel',\n                        help='Random erase mode (default: \"pixel\")')\n    parser.add_argument('--recount', type=int, default=1,\n                        help='Random erase count (default: 1)')\n    parser.add_argument('--resplit', action='store_true', default=False,\n                        help='Do not random erase first (clean) augmentation split')\n\n    # * Mixup params\n    parser.add_argument('--mixup', type=float, default=0,\n                        help='mixup alpha, mixup enabled if > 0.')\n    parser.add_argument('--cutmix', type=float, default=0,\n                        help='cutmix alpha, cutmix enabled if > 0.')\n    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,\n                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\n    parser.add_argument('--mixup_prob', type=float, default=1.0,\n                        help='Probability of performing mixup or cutmix when either/both is enabled')\n    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,\n                        help='Probability of switching to cutmix when both mixup and cutmix enabled')\n    parser.add_argument('--mixup_mode', type=str, default='batch',\n                        help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\n\n    # * Finetuning params\n    parser.add_argument('--finetune', default='',\n                        help='finetune from checkpoint')\n    parser.add_argument('--model_key', default='model|module|state_dict', type=str)\n    parser.add_argument('--model_prefix', default='', type=str)\n    parser.add_argument('--init_scale', default=0.001, type=float)\n    parser.add_argument('--use_mean_pooling', action='store_true')\n    parser.set_defaults(use_mean_pooling=True)\n    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')\n    parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)\n\n    # Dataset parameters\n    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,\n                        help='dataset path')\n    parser.add_argument('--eval_data_path', default=None, type=str,\n                        help='dataset path for evaluation')\n    parser.add_argument('--nb_classes', default=0, type=int,\n                        help='number of the classification types')\n    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')\n\n    parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET100', 'image_folder'],\n                        type=str, help='ImageNet dataset path')\n    parser.add_argument('--output_dir', default='',\n                        help='path where to save, empty for no saving')\n    parser.add_argument('--log_dir', default=None,\n                        help='path where to tensorboard log')\n    parser.add_argument('--device', default='cuda',\n                        help='device to use for training / testing')\n    parser.add_argument('--seed', default=0, type=int)\n    parser.add_argument('--resume', default='',\n                        help='resume from checkpoint')\n    parser.add_argument('--auto_resume', action='store_true')\n    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')\n    parser.set_defaults(auto_resume=True)\n\n    parser.add_argument('--save_ckpt', action='store_true')\n    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')\n    parser.set_defaults(save_ckpt=True)\n\n    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('--eval', action='store_true',\n                        help='Perform evaluation only')\n    parser.add_argument('--dist_eval', action='store_true', default=False,\n                        help='Enabling distributed evaluation')\n    parser.add_argument('--num_workers', default=10, type=int)\n    parser.add_argument('--pin_mem', action='store_true',\n                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\n    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')\n    parser.set_defaults(pin_mem=True)\n\n    # distributed training parameters\n    parser.add_argument('--world_size', default=1, type=int,\n                        help='number of distributed processes')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--dist_on_itp', action='store_true')\n    parser.add_argument('--dist_url', default='env://',\n                        help='url used to set up distributed training')\n\n    parser.add_argument('--enable_deepspeed', action='store_true', default=False)\n    
parser.add_argument('--enable_linear_eval', action='store_true', default=False)\n    parser.add_argument('--enable_multi_print', action='store_true',default=False, help='allow each gpu prints something')\n\n    parser.add_argument('--exp_name', default='', type=str,\n                        help='name of exp. it is helpful when save the checkpoint')\n\n    known_args, _ = parser.parse_known_args()\n\n    if known_args.enable_deepspeed:\n        try:\n            import deepspeed\n            from deepspeed import DeepSpeedConfig\n            parser = deepspeed.add_config_arguments(parser)\n            ds_init = deepspeed.initialize\n        except:\n            print(\"Please 'pip install deepspeed==0.4.0'\")\n            exit(0)\n    else:\n        ds_init = None\n\n    return parser.parse_args(), ds_init\n\n\ndef main(args, ds_init):\n\n    if not args.enable_linear_eval:\n        args.aa = 'rand-m9-mstd0.5-inc1'\n\n    utils.init_distributed_mode(args)\n\n    if ds_init is not None:\n        utils.create_ds_config(args)\n\n    print(args)\n\n    device = torch.device(args.device)\n\n    # fix the seed for reproducibility\n    seed = args.seed + utils.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    # random.seed(seed)\n\n    cudnn.benchmark = True\n\n    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)\n    if args.disable_eval_during_finetuning:\n        dataset_val = None\n    else:\n        dataset_val, _ = build_dataset(is_train=False, args=args)\n\n    if True:  # args.distributed:\n        num_tasks = utils.get_world_size()\n        global_rank = utils.get_rank()\n        sampler_train = torch.utils.data.DistributedSampler(\n            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n        )\n        print(\"Sampler_train = %s\" % str(sampler_train))\n        if args.dist_eval:\n            if len(dataset_val) % num_tasks != 0:\n                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '\n                      'This will slightly alter validation results as extra duplicate entries are added to achieve '\n                      'equal num of samples per-process.')\n            sampler_val = torch.utils.data.DistributedSampler(\n                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)\n        else:\n            sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n    else:\n        sampler_train = torch.utils.data.RandomSampler(dataset_train)\n        sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n\n    if global_rank == 0 and args.log_dir is not None:\n        os.makedirs(args.log_dir, exist_ok=True)\n        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)\n    else:\n        log_writer = None\n\n    data_loader_train = torch.utils.data.DataLoader(\n        dataset_train, sampler=sampler_train,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=True,\n    )\n\n    if dataset_val is not None:\n        data_loader_val = torch.utils.data.DataLoader(\n            dataset_val, sampler=sampler_val,\n            batch_size=int(1.5 * args.batch_size),\n            num_workers=args.num_workers,\n            pin_memory=args.pin_mem,\n            drop_last=False\n        )\n    else:\n        data_loader_val = None\n\n    mixup_fn = None\n    mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None\n    if mixup_active:\n        print(\"Mixup is activated!\")\n        mixup_fn = Mixup(\n            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,\n            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,\n            label_smoothing=args.smoothing, num_classes=args.nb_classes)\n\n    model = create_model(\n        args.model,\n        pretrained=False,\n        num_classes=args.nb_classes,\n        drop_rate=args.drop,\n        drop_path_rate=args.drop_path,\n        attn_drop_rate=args.attn_drop_rate,\n        drop_block_rate=None,\n        use_mean_pooling=args.use_mean_pooling,\n        init_scale=args.init_scale,\n        use_rel_pos_bias=args.rel_pos_bias,\n        use_abs_pos_emb=args.abs_pos_emb,\n        init_values=args.layer_scale_init_value,\n        lin_probe=args.enable_linear_eval,\n        args=args,\n    )\n\n    if args.enable_linear_eval:\n        # freeze all layers but the last fc\n        linear_keyword = 'head'\n        head_norm = 'fc_norm'\n        requires_grad = []\n        for name, param in model.named_parameters():\n            if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword] and head_norm not in name:\n                param.requires_grad = False\n            else:\n                requires_grad.append(name)\n        print(f'require grad parameter: ', requires_grad)\n        # init the fc layer\n        getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)\n        getattr(model, linear_keyword).bias.data.zero_()\n\n    patch_size = model.patch_embed.patch_size\n    print(\"Patch size = %s\" % str(patch_size))\n    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])\n    args.patch_size = patch_size\n\n    if args.finetune:\n        if args.finetune.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.finetune, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.finetune, map_location='cpu')\n\n        print(\"Load ckpt from %s\" % args.finetune)\n        checkpoint_model = None\n        for model_key in args.model_key.split('|'):\n            if model_key in checkpoint:\n                checkpoint_model = checkpoint[model_key]\n                print(\"Load state_dict by model_key = %s\" % model_key)\n                break\n        if checkpoint_model is None:\n            checkpoint_model = checkpoint\n        state_dict = model.state_dict()\n        all_keys = list(checkpoint_model.keys())\n        print(\"##########origin keys:\", len(all_keys), all_keys)\n        # NOTE: remove all decoder keys\n        all_keys = [key for key in all_keys if key.startswith('encoder.')]\n        print(\"all keys:\", all_keys)\n        for key in all_keys:\n            new_key = key.replace('encoder.','')\n            # print(\"new_key:\", new_key)\n            checkpoint_model[new_key] = checkpoint_model[key]\n            checkpoint_model.pop(key)\n            \n        for key in list(checkpoint_model.keys()):\n            if key.startswith('decoder.'):\n                # print(\"key:\", key)\n                checkpoint_model.pop(key)\n            if key.startswith('teacher.'):\n                # print(\"key:\", key)\n                checkpoint_model.pop(key)\n\n        # NOTE: replace norm with fc_norm\n        for key in list(checkpoint_model.keys()):\n            # print(\"new key:\", key)\n     
       if key.startswith('norm.'):\n                new_key = key.replace('norm.','fc_norm.')\n                checkpoint_model[new_key] = checkpoint_model[key]\n                checkpoint_model.pop(key)\n\n        for k in ['head.weight', 'head.bias']:\n            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:\n                print(f\"Removing key {k} from pretrained checkpoint\")\n                del checkpoint_model[k]\n                \n        if model.use_rel_pos_bias and \"rel_pos_bias.relative_position_bias_table\" in checkpoint_model:\n            print(\"Expand the shared relative position embedding to each transformer block. \")\n            num_layers = model.get_num_layers()\n            rel_pos_bias = checkpoint_model[\"rel_pos_bias.relative_position_bias_table\"]\n            for i in range(num_layers):\n                checkpoint_model[\"blocks.%d.attn.relative_position_bias_table\" % i] = rel_pos_bias.clone()\n\n            checkpoint_model.pop(\"rel_pos_bias.relative_position_bias_table\")\n\n        all_keys = list(checkpoint_model.keys())\n\n        for key in all_keys:\n            if \"relative_position_index\" in key:\n                checkpoint_model.pop(key)\n\n            if \"relative_position_bias_table\" in key and args.rel_pos_bias:\n                rel_pos_bias = checkpoint_model[key]\n                src_num_pos, num_attn_heads = rel_pos_bias.size()\n                dst_num_pos, _ = model.state_dict()[key].size()\n                dst_patch_shape = model.patch_embed.patch_shape\n                if dst_patch_shape[0] != dst_patch_shape[1]:\n                    raise NotImplementedError()\n                num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)\n                src_size = int((src_num_pos - num_extra_tokens) ** 0.5)\n                dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)\n                if src_size != dst_size:\n                    print(\"Position interpolate for %s from %dx%d to %dx%d\" % (\n                        key, src_size, src_size, dst_size, dst_size))\n                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\n                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\n\n                    def geometric_progression(a, r, n):\n                        return a * (1.0 - r ** n) / (1.0 - r)\n\n                    left, right = 1.01, 1.5\n                    while right - left > 1e-6:\n                        q = (left + right) / 2.0\n                        gp = geometric_progression(1, q, src_size // 2)\n                        if gp > dst_size // 2:\n                            right = q\n                        else:\n                            left = q\n\n                    # if q > 1.090307:\n                    #     q = 1.090307\n\n                    dis = []\n                    cur = 1\n                    for i in range(src_size // 2):\n                        dis.append(cur)\n                        cur += q ** (i + 1)\n\n                    r_ids = [-_ for _ in reversed(dis)]\n\n                    x = r_ids + [0] + dis\n                    y = r_ids + [0] + dis\n\n                    t = dst_size // 2.0\n                    dx = np.arange(-t, t + 0.1, 1.0)\n                    dy = np.arange(-t, t + 0.1, 1.0)\n\n                    print(\"Original positions = %s\" % str(x))\n                    print(\"Target positions = %s\" % str(dx))\n\n                    all_rel_pos_bias = []\n\n                    for i 
in range(num_attn_heads):\n                        z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()\n                        f = interpolate.interp2d(x, y, z, kind='cubic')\n                        all_rel_pos_bias.append(\n                            torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))\n\n                    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n\n                    new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)\n                    checkpoint_model[key] = new_rel_pos_bias\n\n        print(\"##############new keys:\", len(checkpoint_model), checkpoint_model.keys())\n        #print(\"##############model:\", model)\n\n        # interpolate position embedding\n        if 'pos_embed' in checkpoint_model and args.abs_pos_emb:\n            pos_embed_checkpoint = checkpoint_model['pos_embed']\n            embedding_size = pos_embed_checkpoint.shape[-1]\n            num_patches = model.patch_embed.num_patches\n            num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n            # height (== width) for the checkpoint position embedding\n            orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n            # height (== width) for the new position embedding\n            new_size = int(num_patches ** 0.5)\n            # class_token and dist_token are kept unchanged\n            if orig_size != new_size:\n                print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n                # only the position tokens are interpolated\n                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n                pos_tokens = torch.nn.functional.interpolate(\n                    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n                checkpoint_model['pos_embed'] = new_pos_embed\n\n        utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)\n        # model.load_state_dict(checkpoint_model, strict=False)\n\n    model.to(device)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume='')\n        print(\"Using EMA with decay = %.8f\" % args.model_ema_decay)\n\n    model_without_ddp = model\n    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n    print(\"Model = %s\" % str(model_without_ddp))\n    print('number of params:', n_parameters)\n\n    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()\n    num_training_steps_per_epoch = len(dataset_train) // total_batch_size\n    \n    print(\"LR = %.8f\" % args.lr)\n    print(\"Batch size = %d\" % total_batch_size)\n    print(\"Update frequent = %d\" % args.update_freq)\n    print(\"Number of training examples = %d\" % len(dataset_train))\n    print(\"Number of training training per epoch = %d\" % num_training_steps_per_epoch)\n\n    num_layers = 
model_without_ddp.get_num_layers()\n    if args.layer_decay < 1.0:\n        assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))\n    else:\n        assigner = None\n\n    if assigner is not None:\n        print(\"Assigned values = %s\" % str(assigner.values))\n\n    skip_weight_decay_list = model.no_weight_decay()\n    print(\"Skip weight decay list: \", skip_weight_decay_list)\n\n    if args.disable_weight_decay_on_rel_pos_bias:\n        for i in range(num_layers):\n            skip_weight_decay_list.add(\"blocks.%d.attn.relative_position_bias_table\" % i)\n\n    if args.enable_deepspeed:\n        loss_scaler = None\n        optimizer_params = get_parameter_groups(\n            model, args.weight_decay, skip_weight_decay_list,\n            assigner.get_layer_id if assigner is not None else None,\n            assigner.get_scale if assigner is not None else None)\n        model, optimizer, _, _ = ds_init(\n            args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,\n        )\n\n        print(\"model.gradient_accumulation_steps() = %d\" % model.gradient_accumulation_steps())\n        assert model.gradient_accumulation_steps() == args.update_freq\n    else:\n        if args.distributed:\n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)\n            model_without_ddp = model.module\n\n        optimizer = create_optimizer(\n            args, model_without_ddp, skip_list=skip_weight_decay_list,\n            get_num_layer=assigner.get_layer_id if assigner is not None else None, \n            get_layer_scale=assigner.get_scale if assigner is not None else None)\n        loss_scaler = NativeScaler()\n\n    print(\"Use step level LR scheduler!\")\n    lr_schedule_values = utils.cosine_scheduler(\n        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,\n        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,\n    )\n    if args.weight_decay_end is None:\n        args.weight_decay_end = args.weight_decay\n    wd_schedule_values = utils.cosine_scheduler(\n        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)\n    print(\"Max WD = %.7f, Min WD = %.7f\" % (max(wd_schedule_values), min(wd_schedule_values)))\n\n    if mixup_fn is not None:\n        # smoothing is handled with mixup label transform\n        criterion = SoftTargetCrossEntropy()\n    elif args.smoothing > 0.:\n        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)\n    else:\n        criterion = torch.nn.CrossEntropyLoss()\n\n    print(\"criterion = %s\" % str(criterion))\n\n    utils.auto_load_model(\n        args=args, model=model, model_without_ddp=model_without_ddp,\n        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)\n\n    if args.eval:\n        test_stats = evaluate(data_loader_val, model, device)\n        print(f\"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%\")\n        exit(0)\n\n    print(f\"Start training for {args.epochs} epochs\")\n    start_time = time.time()\n    max_accuracy = 0.0\n    for epoch in range(args.start_epoch, args.epochs):\n        if args.distributed:\n            data_loader_train.sampler.set_epoch(epoch)\n        if log_writer is not None:\n            log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)\n\n        train_stats = train_one_epoch(\n  
          model, criterion, data_loader_train, optimizer,\n            device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,\n            log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,\n            lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,\n            num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,\n        )\n        if args.output_dir and args.save_ckpt:\n            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:\n                utils.save_model(\n                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,\n                    loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)\n        if data_loader_val is not None:\n            test_stats = evaluate(data_loader_val, model, device)\n            print(f\"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%\")\n            if max_accuracy < test_stats[\"acc1\"]:\n                max_accuracy = test_stats[\"acc1\"]\n                if args.output_dir and args.save_ckpt:\n                    utils.save_model(\n                        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,\n                        loss_scaler=loss_scaler, epoch=\"best\", model_ema=model_ema)\n\n            print(f'Max accuracy: {max_accuracy:.2f}%')\n            if log_writer is not None:\n                log_writer.update(test_acc1=test_stats['acc1'], head=\"perf\", step=epoch)\n                log_writer.update(test_acc5=test_stats['acc5'], head=\"perf\", step=epoch)\n                log_writer.update(test_loss=test_stats['loss'], head=\"perf\", step=epoch)\n\n            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                         **{f'test_{k}': v for k, v in test_stats.items()},\n                         'epoch': epoch,\n                         'n_parameters': n_parameters}\n        else:\n            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                         # **{f'test_{k}': v for k, v in test_stats.items()},\n                         'epoch': epoch,\n                         'n_parameters': n_parameters}\n\n        if args.output_dir and utils.is_main_process():\n            if log_writer is not None:\n                log_writer.flush()\n            with open(os.path.join(args.output_dir, \"log.txt\"), mode=\"a\", encoding=\"utf-8\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print('Training time {}'.format(total_time_str))\n\n\nif __name__ == '__main__':\n    opts, ds_init = get_args()\n    if opts.output_dir:\n        Path(opts.output_dir).mkdir(parents=True, exist_ok=True)\n    main(opts, ds_init)\n"
  },
  {
    "path": "tools/run_linear.py",
    "content": "import argparse\nimport datetime\nimport json\nimport numpy as np\nimport os\nimport time\nfrom pathlib import Path\n\nimport torch\nimport torch.backends.cudnn as cudnn\n\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\n\nimport timm\n\nassert timm.__version__ == \"0.3.2\" # version check\nfrom timm.models.layers import trunc_normal_\n\nimport linear_util.misc as misc\nfrom linear_util.pos_embed import interpolate_pos_embed\nfrom linear_util.misc import NativeScalerWithGradNormCount as NativeScaler\nfrom linear_util.lars import LARS\nfrom linear_util.crop import RandomResizedCrop\n\nimport models.modeling_finetune as models_vit\n\nfrom linear_util.engine_finetune import train_one_epoch, evaluate\n\n\ndef setup_for_distributed(rank):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    import builtins as __builtin__\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        if rank==0:\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\ndef get_args_parser():\n    parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False)\n    parser.add_argument('--batch_size', default=512, type=int,\n                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')\n    parser.add_argument('--epochs', default=90, type=int)\n    parser.add_argument('--accum_iter', default=1, type=int,\n                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')\n\n    # Model parameters\n    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',\n                        help='Name of model to train')\n\n    # Optimizer parameters\n    parser.add_argument('--weight_decay', type=float, default=0,\n                        help='weight decay (default: 0 for linear probe following MoCo v1)')\n\n    parser.add_argument('--lr', type=float, default=None, metavar='LR',\n                        help='learning rate (absolute lr)')\n    parser.add_argument('--blr', type=float, default=0.1, metavar='LR',\n                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')\n\n    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0')\n\n    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',\n                        help='epochs to warmup LR')\n\n    # * Finetuning params\n    parser.add_argument('--finetune', default='',\n                        help='finetune from checkpoint')\n                        \n    # Dataset parameters\n    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,\n                        help='dataset path')\n    parser.add_argument('--nb_classes', default=1000, type=int,\n                        help='number of the classification types')\n\n    parser.add_argument('--output_dir', default='./output_dir',\n                        help='path where to save, empty for no saving')\n    parser.add_argument('--log_dir', default='./output_dir',\n                        help='path where to tensorboard log')\n    parser.add_argument('--device', default='cuda',\n                        help='device to use for training / testing')\n    parser.add_argument('--seed', default=0, type=int)\n    
parser.add_argument('--resume', default='',\n                        help='resume from checkpoint')\n\n    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('--eval', action='store_true',\n                        help='Perform evaluation only')\n    parser.add_argument('--dist_eval', action='store_true', default=False,\n                        help='Enabling distributed evaluation (recommended during training for faster monitor')\n    parser.add_argument('--num_workers', default=10, type=int)\n    parser.add_argument('--pin_mem', action='store_true',\n                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\n    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')\n    parser.set_defaults(pin_mem=True)\n\n    # distributed training parameters\n    parser.add_argument('--world_size', default=1, type=int,\n                        help='number of distributed processes')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--dist_on_itp', action='store_true')\n    parser.add_argument('--dist_url', default='env://',\n                        help='url used to set up distributed training')\n\n\n    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                        help='Dropout rate (default: 0.)')\n    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',\n                        help='Attention dropout rate (default: 0.)')\n    parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',\n                        help='Drop path rate (default: 0.1)')\n    parser.add_argument('--init_scale', default=0.001, type=float)\n    parser.add_argument('--use_mean_pooling', action='store_true')\n    parser.set_defaults(use_mean_pooling=True)\n    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')\n    parser.add_argument('--rel_pos_bias', action='store_true')\n    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')\n    parser.set_defaults(rel_pos_bias=True)\n    parser.add_argument('--abs_pos_emb', action='store_true')\n    parser.set_defaults(abs_pos_emb=False)\n    parser.add_argument('--sin_pos_emb', action='store_true')\n    parser.set_defaults(sin_pos_emb=True)\n    parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb')\n    parser.add_argument('--layer_scale_init_value', default=0.1, type=float, \n                        help=\"0.1 for base, 1e-5 for large. set 0 to disable layer scale\")\n    parser.add_argument('--enable_linear_eval', action='store_true', default=False)\n\n    parser.add_argument('--exp_name', default='', type=str,\n                        help='name of exp. 
it is helpful when save the checkpoint')\n\n    parser.add_argument('--save_freq', default=50, type=int,\n                        help='freq of saving models')\n    parser.add_argument('--linear_type', default='standard', type=str,\n                        help='standard or attentive')\n    return parser\n\n\ndef main(args):\n    misc.init_distributed_mode(args)\n\n    setup_for_distributed(args.local_rank)\n\n    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))\n    print(\"{}\".format(args).replace(', ', ',\\n'))\n\n    device = torch.device(args.device)\n\n    # fix the seed for reproducibility\n    seed = args.seed + misc.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    cudnn.benchmark = True\n\n    # linear probe: weak augmentation\n    transform_train = transforms.Compose([\n            RandomResizedCrop(224, interpolation=3),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n    transform_val = transforms.Compose([\n            transforms.Resize(256, interpolation=3),\n            transforms.CenterCrop(224),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)\n    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)\n    print(dataset_train)\n    print(dataset_val)\n\n    if True:  # args.distributed:\n        num_tasks = misc.get_world_size()\n        global_rank = misc.get_rank()\n        sampler_train = torch.utils.data.DistributedSampler(\n            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n        )\n        print(\"Sampler_train = %s\" % str(sampler_train))\n        if args.dist_eval:\n            if len(dataset_val) % num_tasks != 0:\n                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
'\n                      'This will slightly alter validation results as extra duplicate entries are added to achieve '\n                      'equal num of samples per-process.')\n            sampler_val = torch.utils.data.DistributedSampler(\n                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias\n        else:\n            sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n    else:\n        sampler_train = torch.utils.data.RandomSampler(dataset_train)\n        sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n\n    if global_rank == 0 and args.log_dir is not None and not args.eval:\n        os.makedirs(args.log_dir, exist_ok=True)\n        # log_writer = SummaryWriter(log_dir=args.log_dir)\n        log_writer = None\n    else:\n        log_writer = None\n\n    data_loader_train = torch.utils.data.DataLoader(\n        dataset_train, sampler=sampler_train,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=True,\n    )\n\n    data_loader_val = torch.utils.data.DataLoader(\n        dataset_val, sampler=sampler_val,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=False\n    )\n\n    model = models_vit.__dict__[args.model](\n        num_classes=args.nb_classes,\n        drop_rate=args.drop,\n        drop_path_rate=args.drop_path,\n        attn_drop_rate=args.attn_drop_rate,\n        use_mean_pooling=args.use_mean_pooling,\n        init_scale=args.init_scale,\n        use_rel_pos_bias=args.rel_pos_bias,\n        use_abs_pos_emb=args.abs_pos_emb,\n        init_values=args.layer_scale_init_value,\n        lin_probe=args.enable_linear_eval,\n        args=args,\n    )\n\n    if args.finetune and not args.eval:\n        checkpoint = torch.load(args.finetune, map_location='cpu')\n\n        print(\"Load pre-trained checkpoint from: %s\" % args.finetune)\n        checkpoint_model = checkpoint['model']\n        state_dict = model.state_dict()\n        for k in ['head.weight', 'head.bias']:\n            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:\n                print(f\"Removing key {k} from pretrained checkpoint\")\n                del checkpoint_model[k]\n\n        for key in list(checkpoint_model.keys()):\n            if 'encoder.' in key:\n                new_key = key.replace('encoder.','')\n                checkpoint_model[new_key] = checkpoint_model[key]\n                checkpoint_model.pop(key)\n            if 'teacher' in key or 'decoder' in key:\n                checkpoint_model.pop(key)\n\n        if args.rel_pos_bias and \"rel_pos_bias.relative_position_bias_table\" in checkpoint_model:\n            print(\"Expand the shared relative position embedding to each transformer block. 
\")\n            num_layers = model.get_num_layers()\n            rel_pos_bias = checkpoint_model[\"rel_pos_bias.relative_position_bias_table\"]\n            for i in range(num_layers):\n                checkpoint_model[\"blocks.%d.attn.relative_position_bias_table\" % i] = rel_pos_bias.clone()\n\n            checkpoint_model.pop(\"rel_pos_bias.relative_position_bias_table\")     \n\n        all_keys = list(checkpoint_model.keys())\n\n        for key in all_keys:\n            if \"relative_position_index\" in key:\n                checkpoint_model.pop(key)\n\n            if \"relative_position_bias_table\" in key and args.rel_pos_bias:\n                rel_pos_bias = checkpoint_model[key]\n                src_num_pos, num_attn_heads = rel_pos_bias.size()\n                dst_num_pos, _ = model.state_dict()[key].size()\n                dst_patch_shape = model.patch_embed.patch_shape\n                if dst_patch_shape[0] != dst_patch_shape[1]:\n                    raise NotImplementedError()\n                num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)\n                src_size = int((src_num_pos - num_extra_tokens) ** 0.5)\n                dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)\n                if src_size != dst_size:\n                    print(\"Position interpolate for %s from %dx%d to %dx%d\" % (\n                        key, src_size, src_size, dst_size, dst_size))\n                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\n                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\n\n                    def geometric_progression(a, r, n):\n                        return a * (1.0 - r ** n) / (1.0 - r)\n\n                    left, right = 1.01, 1.5\n                    while right - left > 1e-6:\n                        q = (left + right) / 2.0\n                        gp = geometric_progression(1, q, src_size // 2)\n                        if gp > dst_size // 2:\n                            right = q\n                        else:\n                            left = q\n\n                    # if q > 1.090307:\n                    #     q = 1.090307\n\n                    dis = []\n                    cur = 1\n                    for i in range(src_size // 2):\n                        dis.append(cur)\n                        cur += q ** (i + 1)\n\n                    r_ids = [-_ for _ in reversed(dis)]\n\n                    x = r_ids + [0] + dis\n                    y = r_ids + [0] + dis\n\n                    t = dst_size // 2.0\n                    dx = np.arange(-t, t + 0.1, 1.0)\n                    dy = np.arange(-t, t + 0.1, 1.0)\n\n                    print(\"Original positions = %s\" % str(x))\n                    print(\"Target positions = %s\" % str(dx))\n\n                    all_rel_pos_bias = []\n\n                    for i in range(num_attn_heads):\n                        z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()\n                        f = interpolate.interp2d(x, y, z, kind='cubic')\n                        all_rel_pos_bias.append(\n                            torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))\n\n                    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n\n                    new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)\n                    checkpoint_model[key] = new_rel_pos_bias\n\n        # interpolate position embedding\n        interpolate_pos_embed(model, checkpoint_model)\n\n    
    # load pre-trained model\n        msg = model.load_state_dict(checkpoint_model, strict=False)\n        print(msg)\n\n        trunc_normal_(model.head.weight, std=0.01)\n\n\n    model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)\n\n    requires_grad = []\n\n    for _, p in model.named_parameters():\n        p.requires_grad = False\n    for nname, p in model.head.named_parameters():\n        p.requires_grad = True\n        requires_grad.append(nname)\n    \n    print(f'require grad parameter: ', requires_grad)\n\n    model.to(device)\n\n    model_without_ddp = model\n    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n    print(\"Model = %s\" % str(model_without_ddp))\n    print('number of params (M): %.2f' % (n_parameters / 1.e6))\n\n    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()\n    \n    if args.lr is None:  # only base_lr is specified\n        args.lr = args.blr * eff_batch_size / 256\n\n    print(\"base lr: %.2e\" % (args.lr * 256 / eff_batch_size))\n    print(\"actual lr: %.2e\" % args.lr)\n\n    print(\"accumulate grad iterations: %d\" % args.accum_iter)\n    print(\"effective batch size: %d\" % eff_batch_size)\n\n    if args.distributed:\n        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])\n        model_without_ddp = model.module\n\n    optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n    print(optimizer)\n    loss_scaler = NativeScaler()\n\n    criterion = torch.nn.CrossEntropyLoss()\n\n    print(\"criterion = %s\" % str(criterion))\n\n    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)\n\n    if args.eval:\n        test_stats = evaluate(data_loader_val, model, device)\n        print(f\"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%\")\n        exit(0)\n\n    print(f\"Start training for {args.epochs} epochs\")\n    start_time = time.time()\n    max_accuracy = 0.0\n    for epoch in range(args.start_epoch, args.epochs):\n        if args.distributed:\n            data_loader_train.sampler.set_epoch(epoch)\n                \n        train_stats = train_one_epoch(\n            model, criterion, data_loader_train,\n            optimizer, device, epoch, loss_scaler,\n            max_norm=None,\n            log_writer=log_writer,\n            args=args\n        )\n        if args.output_dir and (epoch % args.save_freq == 0 or epoch + 1 == args.epochs):\n            misc.save_model(\n                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,\n                loss_scaler=loss_scaler, epoch=epoch)\n\n        test_stats = evaluate(data_loader_val, model, device)\n        print(f\"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%\")\n        max_accuracy = max(max_accuracy, test_stats[\"acc1\"])\n        print(f'Max accuracy: {max_accuracy:.2f}%')\n\n        if log_writer is not None:\n            log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)\n            log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)\n            log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)\n\n        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                        **{f'test_{k}': v for k, v in test_stats.items()},\n                        
'epoch': epoch,\n                        'n_parameters': n_parameters}\n\n        if args.output_dir and misc.is_main_process():\n            if log_writer is not None:\n                log_writer.flush()\n            with open(os.path.join(args.output_dir, \"log.txt\"), mode=\"a\", encoding=\"utf-8\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print('Training time {}'.format(total_time_str))\n\n\nif __name__ == '__main__':\n    args = get_args_parser()\n    args = args.parse_args()\n    if args.output_dir:\n        Path(args.output_dir).mkdir(parents=True, exist_ok=True)\n    main(args)"
  },
  {
    "path": "tools/run_pretraining.py",
    "content": "import argparse\nimport datetime\nimport numpy as np\nimport time\nimport torch\nimport torch.backends.cudnn as cudnn\nimport json\nimport os\nimport shutil\n\nfrom pathlib import Path\n\nfrom timm.models import create_model\nfrom furnace.optim_factory import create_optimizer\n\nfrom furnace.datasets import build_cae_pretraining_dataset\nfrom furnace.engine_for_pretraining import train_one_epoch\nfrom furnace.utils import NativeScalerWithGradNormCount as NativeScaler\nimport furnace.utils as utils\nfrom models import modeling_cae\nimport torch.distributed as dist\n\ndef get_args():\n    parser = argparse.ArgumentParser('pre-training script', add_help=False)\n    parser.add_argument('--batch_size', default=64, type=int)\n    parser.add_argument('--epochs', default=300, type=int)\n    parser.add_argument('--save_ckpt_freq', default=50, type=int)\n    parser.add_argument(\"--discrete_vae_weight_path\", type=str)\n    parser.add_argument(\"--discrete_vae_type\", type=str, default=\"dall-e\", help='[dall-e, vqgan_gumbel_f8_8192, customized]')\n    parser.add_argument('--dvae_num_layers', default=3, type=int)\n\n    # Model parameters\n    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',\n                        help='Name of model to train')\n    parser.add_argument('--rel_pos_bias', action='store_true', default=False)\n    parser.add_argument('--abs_pos_emb', action='store_true', default=False)\n    parser.add_argument('--sincos_pos_emb', action='store_true', default=False)\n    parser.add_argument('--layer_scale_init_value', default=0.1, type=float, \n                        help=\"0.1 for base, 1e-5 for large. set 0 to disable layer scale\")\n\n    parser.add_argument('--input_size', default=224, type=int,\n                        help='images input size for backbone')\n    parser.add_argument('--second_input_size', default=112, type=int,\n                        help='images input size for discrete vae')\n\n    parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',\n                        help='Drop path rate (default: 0)')\n\n    # Optimizer parameters\n    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                        help='Optimizer (default: \"adamw\"')\n    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',\n                        help='Optimizer Epsilon (default: 1e-8)')\n    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',\n                        help='Optimizer Betas (default: 0.9, 0.98, use opt default)')\n    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',\n                        help='Clip gradient norm (default: None, no clipping)')\n    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                        help='SGD momentum (default: 0.9)')\n    parser.add_argument('--weight_decay', type=float, default=0.05,\n                        help='weight decay (default: 0.05)')\n    parser.add_argument('--weight_decay_end', type=float, default=None, help=\"\"\"Final value of the\n        weight decay. We use a cosine schedule for WD. 
\n        (Set the same value with args.weight_decay to keep weight decay no change)\"\"\")\n\n    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',\n                        help='learning rate (default: 5e-4)')\n    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',\n                        help='warmup learning rate (default: 1e-6)')\n    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\n\n    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',\n                        help='epochs to warmup LR, if scheduler supports')\n    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',\n                        help='epochs to warmup LR, if scheduler supports')\n\n    # Augmentation parameters\n    parser.add_argument('--train_interpolation', type=str, default='bicubic',\n                        help='Training interpolation (random, bilinear, bicubic default: \"bicubic\")')\n    parser.add_argument('--second_interpolation', type=str, default='lanczos',\n                        help='Interpolation for discrete vae (random, bilinear, bicubic default: \"lanczos\")')\n\n    # Dataset parameters\n    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,\n                        help='dataset path')\n    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')\n\n    parser.add_argument('--output_dir', default='',\n                        help='path where to save, empty for no saving')\n    parser.add_argument('--log_dir', default=None,\n                        help='path where to tensorboard log')\n    parser.add_argument('--device', default='cuda',\n                        help='device to use for training / testing')\n    parser.add_argument('--seed', default=0, type=int)\n    parser.add_argument('--resume', default='', help='resume from checkpoint')\n    parser.add_argument('--auto_resume', action='store_true')\n    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')\n    parser.set_defaults(auto_resume=True)\n\n    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('--num_workers', default=10, type=int)\n    parser.add_argument('--pin_mem', action='store_true',\n                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\n    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',\n                        help='')\n    parser.set_defaults(pin_mem=True)\n\n    # distributed training parameters\n    parser.add_argument('--world_size', default=1, type=int,\n                        help='number of distributed processes')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--dist_on_itp', action='store_true')\n    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')\n\n    parser.add_argument('--exp_name', default='', type=str, help='it is used when save the checkpoint')\n    parser.add_argument('--enable_multi_print', action='store_true',default=False, help='allow each gpu to print something')\n\n    '''\n    Data augmentation\n    '''\n    # crop size\n    parser.add_argument('--crop_min_size', type=float, default=0.08, help='min size of crop')\n    
parser.add_argument('--crop_max_size', type=float, default=1.0, help='max size of crop')\n    # color jitter\n    parser.add_argument('--color_jitter', type=float, default=0, metavar='PCT', help='Color jitter factor (default: 0)')\n    \n    '''\n    Mask strategy\n    '''\n    parser.add_argument('--mask_generator', default='block', type=str,\n                        help='block or random')\n    # 1. if use block mask, set the num_mask_patches\n    parser.add_argument('--num_mask_patches', default=98, type=int,\n                        help='number of the visual tokens/patches need be masked')\n    parser.add_argument('--max_mask_patches_per_block', type=int, default=None)\n    parser.add_argument('--min_mask_patches_per_block', type=int, default=16)\n    # 2. if use random mask, set the mask ratio\n    parser.add_argument('--ratio_mask_patches', default=None, type=float, help=\"mask ratio\")\n\n    '''\n    CAE hyper-parameters\n    '''\n    parser.add_argument('--regressor_depth', default=4, type=int, help='depth of the regressor')\n    parser.add_argument('--decoder_depth', default=4, type=int, help='depth of the decoder')\n    parser.add_argument('--decoder_embed_dim', default=768, type=int,\n                        help='dimensionaltiy of embeddings for decoder')\n    parser.add_argument('--decoder_num_heads', default=12, type=int,\n                        help='Number of heads for decoder')\n    parser.add_argument('--decoder_num_classes', default=8192, type=int,\n                        help='Number of classes for decoder')\n    parser.add_argument('--decoder_layer_scale_init_value', default=0.1, type=float,\n                        help='decoder layer scale init value')\n\n    # alignment constraint\n    parser.add_argument('--align_loss_weight', type=float, default=2, help='loss weight for the alignment constraint')\n    parser.add_argument('--base_momentum', type=float, default=0, help='ema weight for the dual path network')\n\n    # init func, borrowed from BEiT\n    parser.add_argument('--fix_init_weight', action='store_true', default=False, help='if true, the fix_init_weight() func will be activated')\n\n\n    return parser.parse_args()\n\n\ndef get_model(args):\n    print(f\"Creating model: {args.model}\")\n    model = create_model(\n        args.model,\n        pretrained=False,\n        drop_path_rate=args.drop_path,\n        drop_block_rate=None,\n        use_abs_pos_emb=args.abs_pos_emb,\n        init_values=args.layer_scale_init_value,\n        args=args,\n    )\n\n    return model\n\n\ndef main(args):\n    utils.init_distributed_mode(args)\n\n    print(args)\n\n    device = torch.device(args.device)\n\n    # fix the seed for reproducibility\n    seed = args.seed + utils.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    cudnn.benchmark = True\n\n    model = get_model(args)\n    patch_size = model.encoder.patch_embed.patch_size\n    print(\"Patch size = %s\" % str(patch_size))\n    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])\n    args.patch_size = patch_size\n\n    # get dataset\n    dataset_train = build_cae_pretraining_dataset(args)\n\n    # prepare discrete vae\n    d_vae = utils.create_d_vae(\n        weight_path=args.discrete_vae_weight_path, d_vae_type=args.discrete_vae_type,\n        device=device, image_size=args.second_input_size, args=args)\n\n    if True:  # args.distributed:\n        num_tasks = utils.get_world_size()\n        global_rank = utils.get_rank()\n        sampler_rank = 
global_rank\n        num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks\n\n        sampler_train = torch.utils.data.DistributedSampler(\n            dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True\n        )\n        print(\"Sampler_train = %s\" % str(sampler_train))\n    else:\n        sampler_train = torch.utils.data.RandomSampler(dataset_train)\n\n    if global_rank == 0 and args.log_dir is not None:\n        os.makedirs(args.log_dir, exist_ok=True)\n        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)\n    else:\n        log_writer = None\n\n    data_loader_train = torch.utils.data.DataLoader(\n        dataset_train, sampler=sampler_train,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=True,\n    )\n\n    model.to(device)\n    model_without_ddp = model\n    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n    print(\"Model = %s\" % str(model_without_ddp))\n    print('number of params:', n_parameters)\n\n    total_batch_size = args.batch_size * utils.get_world_size()\n    print(\"LR = %.8f\" % args.lr)\n    print(\"Batch size = %d\" % total_batch_size)\n    print(\"Number of training steps = %d\" % num_training_steps_per_epoch)\n    print(\"Number of training examples per epoch = %d\" % (total_batch_size * num_training_steps_per_epoch))\n\n    if args.distributed:\n        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)\n        model_without_ddp = model.module\n\n    optimizer = create_optimizer(\n        args, model_without_ddp)\n    loss_scaler = NativeScaler()\n\n    print(\"Use step level LR & WD scheduler!\")\n    lr_schedule_values = utils.cosine_scheduler(\n        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,\n        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,\n    )\n    if args.weight_decay_end is None:\n        args.weight_decay_end = args.weight_decay\n    wd_schedule_values = utils.cosine_scheduler(\n        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)\n    print(\"Max WD = %.7f, Min WD = %.7f\" % (max(wd_schedule_values), min(wd_schedule_values)))\n\n    utils.auto_load_model(\n        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)\n\n    print(f\"Start training for {args.epochs} epochs\")\n    start_time = time.time()\n\n    for epoch in range(args.start_epoch, args.epochs):\n        if args.distributed:\n            data_loader_train.sampler.set_epoch(epoch)\n        if log_writer is not None:\n            log_writer.set_step(epoch * num_training_steps_per_epoch)\n                \n        train_stats = train_one_epoch(\n            model, d_vae, data_loader_train,\n            optimizer, device, epoch, loss_scaler,\n            args.clip_grad, log_writer=log_writer,\n            start_steps=epoch * num_training_steps_per_epoch,\n            lr_schedule_values=lr_schedule_values,\n            wd_schedule_values=wd_schedule_values,\n            args=args,\n        )\n        if args.output_dir:\n            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:\n                utils.save_model(\n                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,\n                    loss_scaler=loss_scaler, epoch=epoch, 
exp_name=args.exp_name)\n\n        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters}\n\n        if args.output_dir and utils.is_main_process():\n            if log_writer is not None:\n                log_writer.flush()\n            with open(os.path.join(args.output_dir, \"log.txt\"), mode=\"a\", encoding=\"utf-8\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print('Training time {}'.format(total_time_str))\n\n\nif __name__ == '__main__':\n    opts = get_args()\n    if opts.output_dir:\n        Path(opts.output_dir).mkdir(parents=True, exist_ok=True)\n    main(opts)\n"
  }
]