[
  {
    "path": "HiFi_Net.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\nfrom utils.utils import *\nfrom utils.custom_loss import IsolatingLossFunction, load_center_radius_api\nfrom models.seg_hrnet import get_seg_model\nfrom models.seg_hrnet_config import get_cfg_defaults\nfrom models.NLCDetection_api import NLCDetection\nfrom PIL import Image\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport argparse\nimport imageio as imageio\n\nclass HiFi_Net():\n    '''\n        FENET is the multi-branch feature extractor.\n        SegNet contains the classification and localization modules.\n        LOSS_MAP is the classification loss function class.\n    '''\n    def __init__(self):\n        device = torch.device('cuda:0')\n        device_ids = [0]\n\n        FENet_cfg = get_cfg_defaults()\n        FENet  = get_seg_model(FENet_cfg).to(device) # load the pre-trained model inside.\n        SegNet = NLCDetection().to(device)\n        FENet  = nn.DataParallel(FENet)\n        SegNet = nn.DataParallel(SegNet)\n\n        self.FENet  = restore_weight_helper(FENet,  \"weights/HRNet\",  750001)\n        self.SegNet = restore_weight_helper(SegNet, \"weights/NLCDetection\", 750001)\n        self.FENet.eval()\n        self.SegNet.eval()\n\n        center, radius = load_center_radius_api()\n        self.LOSS_MAP = IsolatingLossFunction(center,radius).to(device)\n\n    def _transform_image(self, image_name):\n        '''transform the image.'''\n        image = imageio.imread(image_name)\n        image = Image.fromarray(image)\n        image = image.resize((256,256), resample=Image.BICUBIC)\n        image = np.asarray(image)\n        image = image.astype(np.float32) / 255.\n        image = torch.from_numpy(image)\n        image = image.permute(2, 0, 
1)\n        image = torch.unsqueeze(image, 0)\n        return image\n\n    def _normalized_threshold(self, res, prob, threshold=0.5, verbose=False):\n        '''to interpret detection result via omitting the detection decision.'''\n        if res > threshold:\n            decision = \"Forged\"\n            prob = (prob - threshold) / threshold\n        else:\n            decision = 'Real'\n            prob = (threshold - prob) / threshold\n        print(f'Image being {decision} with the confidence {prob*100:.1f}.')\n\n    def detect(self, image_name, verbose=False):\n        \"\"\"\n            Para: image_name is string type variable for the image name.\n            Return:\n                res: binary result for real and forged.\n                prob: the prob being the forged image.\n        \"\"\"\n        with torch.no_grad():\n            img_input = self._transform_image(image_name)\n            output = self.FENet(img_input)\n            mask1_fea, mask1_binary, out0, out1, out2, out3 = self.SegNet(output, img_input)\n            res, prob = one_hot_label_new(out3)\n            res = level_1_convert(res)[0]\n            if not verbose:\n                return res, prob[0]\n            else:\n                self._normalized_threshold(res, prob[0])\n\n    def localize(self, image_name):\n        \"\"\"\n            Para: image_name is string type variable for the image name.\n            Return:\n                binary_mask: forgery mask.\n        \"\"\"\n        with torch.no_grad():\n            img_input = self._transform_image(image_name)\n            output = self.FENet(img_input)\n            mask1_fea, mask1_binary, out0, out1, out2, out3 = self.SegNet(output, img_input)\n            pred_mask, pred_mask_score = self.LOSS_MAP.inference(mask1_fea)   # inference\n            pred_mask_score = pred_mask_score.cpu().numpy()\n            ## 2.3 is the threshold used to seperate the real and fake pixels.\n            ## 2.3 is the dist between center and 
pixel feature in the hyper-sphere.\n            ## for center and pixel feature please refer to \"IsolatingLossFunction\" in custom_loss.py\n            pred_mask_score[pred_mask_score<2.3] = 0.\n            pred_mask_score[pred_mask_score>=2.3] = 1.\n            binary_mask = pred_mask_score[0]        \n            return binary_mask\n\n\ndef inference(img_path):\n    HiFi = HiFi_Net()   # initialize\n    \n    ## detection\n    res3, prob3 = HiFi.detect(img_path)\n    # print(res3, prob3) 1 1.0\n    HiFi.detect(img_path, verbose=True)\n    \n    ## localization\n    binary_mask = HiFi.localize(img_path)\n    binary_mask = Image.fromarray((binary_mask*255.).astype(np.uint8))\n    binary_mask.save('pred_mask.png')\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--img_path', type=str, default='asset/sample_1.jpg')\n    args = parser.parse_args()\n    inference(args.img_path)\n"
  },
  {
    "path": "HiFi_Net_loc.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\nfrom utils.utils import *\nfrom IMD_dataloader import *\nfrom utils.custom_loss import IsolatingLossFunction, load_center_radius\nfrom torch.utils.tensorboard import SummaryWriter\nfrom tqdm import tqdm\nfrom models.seg_hrnet import get_seg_model\nfrom models.seg_hrnet_config import get_cfg_defaults\nfrom models.NLCDetection_loc import NLCDetection\n\nfrom sklearn import metrics\nfrom sklearn.metrics import roc_auc_score\n\nfrom torchvision.utils import make_grid\nfrom einops import rearrange\nfrom PIL import Image\nfrom sklearn import metrics\n\nimport os\nimport csv\nimport time\nimport torch\nimport torch.nn as nn\nimport argparse\nimport numpy as np\n\ndevice = torch.device('cuda:0')\ndevice_ids = [0]\n\ndef config(args):\n    '''Set up input configurations.'''\n    args.crop_size = [args.crop_size, args.crop_size]\n    # cuda_list = args.list_cuda\n    global device \n    device = torch.device('cuda:0')\n    # global device_ids\n    # device_ids = device_ids_return(cuda_list)\n\n    args.save_dir    = 'lr_' + str(args.learning_rate)+'_loc'\n    FENet_dir, SegNet_dir = args.save_dir+'/HRNet', args.save_dir+'/NLCDetection'\n    FENet_cfg = get_cfg_defaults()\n    FENet  = get_seg_model(FENet_cfg).to(device) # load the pre-trained model inside.\n    SegNet = NLCDetection().to(device)\n\n    FENet  = nn.DataParallel(FENet, device_ids=device_ids)\n    SegNet = nn.DataParallel(SegNet, device_ids=device_ids)\n\n    writer = None\n\n    return args, writer, FENet, SegNet, FENet_dir, SegNet_dir\n\ndef restore_weight(args, FENet, SegNet, FENet_dir, SegNet_dir):\n    '''load FENet, SegNet and optimizer.'''\n    params      = list(FENet.parameters()) + list(SegNet.parameters()) 
\n    optimizer   = torch.optim.Adam(params, lr=args.learning_rate)\n    initial_epoch = findLastCheckpoint(save_dir=SegNet_dir)\n\n    # load FENet and SegNet weight:\n    FENet  = restore_weight_helper(FENet,  FENet_dir,  initial_epoch)\n    SegNet = restore_weight_helper(SegNet, SegNet_dir, initial_epoch)\n    optimizer  = restore_optimizer(optimizer, SegNet_dir)\n\n    return optimizer, initial_epoch\n\ndef Inference_loc(\n                args, FENet, SegNet, LOSS_MAP, tb_writer, \n                iter_num=None, \n                save_tag=False, \n                localization=True\n                ):\n    '''\n        the inference pipeline for the pre-trained model.\n        the image-level detection will dump to the csv file.\n        the pixel-level localization will be saved as in the npy file.\n    '''\n\n    for val_tag in [0,1,2,3,4]:\n\n        val_data_loader, data_label = eval_dataset_loader_init(args, val_tag)\n        print(f\"working on the dataset: {data_label}.\")\n        F1_lst, auc_lst = [], []\n        with torch.no_grad():\n            FENet.eval()\n            SegNet.eval()\n            for step, val_data in enumerate(tqdm(val_data_loader)):\n                image, mask, cls, image_names = val_data\n                image, mask = image.to(device), mask.to(device)\n                mask = torch.squeeze(mask, axis=1)\n\n                # model\n                try:\n                    output = FENet(image)\n                    mask1_fea, mask_binary, out0, out1, out2, out3 = SegNet(output, image)\n                except:\n                    print(f\"does not work on the \", image_names)\n                    continue\n                if args.loss_type == 'dm':\n                    loss_map, loss_manip, loss_nat = LOSS_MAP(mask1_fea, mask)\n                    pred_mask = LOSS_MAP.dis_curBatch.squeeze(dim=1)\n                    pred_mask_score = LOSS_MAP.dist.squeeze(dim=1)\n                elif args.loss_type == 'ce':\n                    
pred_mask_score = mask_binary\n                    pred_mask = torch.zeros_like(mask_binary)\n                    pred_mask[mask_binary > 0.5] = 1\n                    pred_mask[mask_binary <= 0.5] = 0\n                viz_log(args, mask, pred_mask, image, iter_num, f\"{step}_{val_tag}\", mode='eval')\n\n                mask = torch.unsqueeze(mask, axis=1)\n                for img_idx, cur_img_name in enumerate(image_names):\n\n                    mask_ = torch.unsqueeze(mask[img_idx,0], 0)\n                    pred_mask_ = torch.unsqueeze(pred_mask[img_idx], 0)\n                    pred_mask_score_ = torch.unsqueeze(pred_mask_score[img_idx], 0)\n\n                    mask_ = mask_.cpu().clone().cpu().numpy().reshape(-1)\n                    pred_mask_ = pred_mask_.cpu().clone().cpu().numpy().reshape(-1)\n                    pred_mask_score_ = pred_mask_score_.cpu().clone().cpu().numpy().reshape(-1)\n\n                    F1_a  = metrics.f1_score(mask_, pred_mask_, average='macro')\n                    auc_a = metrics.roc_auc_score(mask_, pred_mask_score_)\n\n                    pred_mask_[np.where(pred_mask_ == 0)] = 1\n                    pred_mask_[np.where(pred_mask_ == 1)] = 0\n\n                    F1_b  = metrics.f1_score(mask_, pred_mask_, average='macro')\n                    if F1_a > F1_b:\n                        F1 = F1_a\n                    else:\n                        F1 = F1_b\n                    F1_lst.append(F1)\n                    AUC_score = auc_a if auc_a > 0.5 else 1-auc_a\n                    auc_lst.append(AUC_score)\n                    \n        print(\"F1: \", np.mean(F1_lst))\n        print(\"AUC: \", np.mean(auc_lst))\n\ndef main(args):\n    ## Set up the configuration.\n    args, writer, FENet, SegNet, FENet_dir, SegNet_dir = config(args)\n\n    ## load FENet and SegNet weight:\n    if args.loss_type == 'ce':\n        FENet  = restore_weight_helper(FENet,  \"weights/HRNet\",  225000)\n        SegNet = restore_weight_helper(SegNet, 
\"weights/NLCDetection\", 225000)    \n    elif args.loss_type == 'dm':\n        FENet  = restore_weight_helper(FENet,  \"weights/HRNet\",  315000)\n        SegNet = restore_weight_helper(SegNet, \"weights/NLCDetection\", 315000)\n    else:\n        raise ValueError\n\n    ## Set up the loss function.\n    center, radius = load_center_radius(args, FENet, SegNet, \n                                        train_data_loader=None, \n                                        center_radius_dir='./center_loc')\n    CE_loss  = nn.CrossEntropyLoss().to(device)\n    BCE_loss = nn.BCELoss(reduction='none').to(device)\n    LOSS_MAP = IsolatingLossFunction(center,radius).to(device)\n\n    Inference_loc(\n                args, \n                FENet, \n                SegNet,\n                LOSS_MAP,\n                tb_writer=writer, \n                iter_num=99999, \n                save_tag=True, \n                localization=True\n                )\n    print(\"after saving the points...\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-l','--list_cuda', nargs='+', help='<Required> Set flag')\n    parser.add_argument('-lr', '--learning_rate', type=float, default=5e-5)\n    parser.add_argument('--num_epochs', type=int, default=3)\n    parser.add_argument('--lr_gamma', type=float, default=2.0)\n    parser.add_argument('--lr_backbone', type=float, default=0.9)\n    parser.add_argument('--patience', type=int, default=30)\n    parser.add_argument('--step_factor', type=float, default=0.95)\n    parser.add_argument('--dis_step', type=int, default=50)\n    parser.add_argument('--val_step', type=int, default=500)\n\n    ## train hyper-parameters\n    parser.add_argument('--crop_size', type=int, default=256)\n    parser.add_argument('--val_num', type=int, default=200, help='val sample number.')\n    parser.add_argument('--train_num', type=int, default=360000, help='train sample number.')\n    parser.add_argument('--train_tag', 
type=int, default=0)\n    parser.add_argument('--val_tag', type=int, default=0)\n    parser.add_argument('--val_all', type=int, default=1)\n    parser.add_argument('--ablation', type=str, default='local', \n                            choices=['base', 'fg', 'local', 'full'], \n                            help='exp for one-shot, fine_grain, plus localization, plus pconv')\n    parser.add_argument('--val_loc_tag', action='store_true')\n    parser.add_argument('--fine_tune', action='store_true')\n    parser.add_argument('--debug_mode', action='store_true')\n    parser.set_defaults(val_loc_tag=True)\n    parser.set_defaults(fine_tune=True)\n\n    parser.add_argument('--train_ratio', nargs='+', default=\"0.4 0.4 0.2\", help='deprecated')\n    parser.add_argument('--path', type=str, default=\"\", help='deprecated')\n    parser.add_argument('--train_bs', type=int, default=10, help='batch size in the training.')\n    parser.add_argument('--val_bs', type=int, default=10, help='batch size in the validation.')\n    parser.add_argument('--percent', type=float, default=1.0, help='label dataset.')\n    parser.add_argument('--loss_type', type=str, default='ce',\n                            choices=['ce', 'dm'], help='ce or deep metric.')\n\n    ## inference hyperparameters:\n    parser.add_argument('--initial_epoch', type=int, default=70500)\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "HiFi_Net_loc.sh",
    "content": "source ~/.bashrc\nconda activate HiFi_Net\nCUDA_NUM=2\nCUDA_VISIBLE_DEVICES=$CUDA_NUM python HiFi_Net_loc.py "
  },
  {
    "path": "IMD_dataloader.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo, Xiaohong Liu.\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\nfrom torch.utils.data import DataLoader\nfrom utils.load_data import TrainData, ValData\nfrom utils.load_edata import *\n\ndef train_dataset_loader_init(args):\n\ttrain_dataset = TrainData(args)\n\ttrain_data_loader = DataLoader(\n\t\t\t\t\t\t\t\ttrain_dataset, \n\t\t\t\t\t\t\t\tbatch_size=args.train_bs, \n\t\t\t\t\t\t\t\tshuffle=True, \n\t\t\t\t\t\t\t\t# shuffle=False,\n\t\t\t\t\t\t\t\tnum_workers=8\n\t\t\t\t\t\t\t\t)\n\treturn train_data_loader\n\ndef infer_dataset_loader_init(args, shuffle=True, bs=8):\n\tval_dataset = ValData(args)\n\tval_data_loader = DataLoader(\n\t\t\t\t\t\t\t\tval_dataset, \n\t\t\t\t\t\t\t\tbatch_size=bs,\n\t\t\t\t\t\t\t\tshuffle=shuffle, \n\t\t\t\t\t\t\t\t# shuffle=True, \n\t\t\t\t\t\t\t\tnum_workers=8\n\t\t\t\t\t\t\t\t)\n\treturn val_data_loader\n\ndef eval_dataset_loader_init(args, val_tag, batch_size=1):\n\t\n\tif val_tag == 0:\n\t\tdata_label = 'columbia'\n\t\tval_data_loader = DataLoader(ValColumbia(args), batch_size=batch_size, shuffle=False,\n\t\t\t\t\t\t\t\t\t num_workers=0)\n\telif val_tag == 1:\n\t\tdata_label = 'coverage'\n\t\tval_data_loader = DataLoader(ValCoverage(args), batch_size=batch_size, shuffle=False,\n\t\t\t\t\t\t\t\t\t num_workers=0)\n\telif val_tag == 2:\n\t\tdata_label = 'casia'\n\t\tval_data_loader = DataLoader(ValCasia(args), batch_size=batch_size, shuffle=False,\n\t\t\t\t\t\t\t\t\t num_workers=0)\n\telif val_tag == 3:\n\t\tdata_label = 'NIST16'\n\t\tval_data_loader = DataLoader(ValNIST16(args), batch_size=batch_size, shuffle=False,\n\t\t\t\t\t\t\t\t\t num_workers=0)\n\telif val_tag == 4:\n\t\tdata_label = 'IMD2020'\n\t\tval_data_loader = DataLoader(ValIMD2020(args), batch_size=batch_size, shuffle=False,\n\t\t\t\t\t\t\t\t\t 
num_workers=0)\n\treturn val_data_loader, data_label"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Xiao Guo\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# HiFi_IFDL\n\nThis is the source code for our CVPR $2023$: \"*Hierarchical Fine-Grained Image Forgery Detection and Localization*.\" [[Arxiv]](https://arxiv.org/pdf/2303.17111.pdf)\n\nAuthors: [Xiao Guo](https://scholar.google.com/citations?user=Gkc-lAEAAAAJ&hl=en), [Xiaohong Liu](https://jhc.sjtu.edu.cn/~xiaohongliu/), [Zhiyuan Ren](https://scholar.google.com/citations?user=Z1ltuXEAAAAJ&hl=en), [Steven Grosz](https://scholar.google.com/citations?user=I1wOjTYUyYAC&hl=en), [Iacopo Masi](https://iacopomasi.github.io/), [Xiaoming Liu](http://cvlab.cse.msu.edu/)\n\n<p align=\"center\">\n  <img src=\"https://github.com/CHELSEA234/HiFi_IFDL/blob/main/figures/overview_4.png\" alt=\"drawing\" width=\"1000\"/>\n</p>\n\n### <a name=\"update\"></a> Updates.\n- [Sep 2024] 👏 The International Journal of Computer Vision (**IJCV**) has accepted the extended version of HiFi-Net, stay tuned~\n- [Aug 2024] The HiFi-Net is integrated into the DeepFake-o-meter v2.0 platform, which is a user-friendly public detection tool designed by the **University at Buffalo**. [[DeepFake-o-meter v2.0]](https://zinc.cse.buffalo.edu/ubmdfl/deep-o-meter/home_login) [[ArXiv]](https://arxiv.org/pdf/2404.13146)\n- [Jul. 
2024] 👏 **ECCV2024** \"Deepfake Explainer\" paper [[ArXiv]](https://arxiv.org/pdf/2402.00126) reports HiFi-Net's deep fake detection performance and the source code is released [[link]](https://github.com/CHELSEA234/HiFi_IFDL/edit/main/applications/deepfake_detection).\n- [Sep 2023] The first version dataset can be acquired via this link: [Dataset Link](https://drive.google.com/drive/folders/1fwBEmW30-e0ECpCNNG3nRU6I9OqJfMAn?usp=sharing)\n- [June 2023] The extended version of our work has been submitted to one of the ~~Machine Learning Journals~~ IJCV.\n- **This GitHub will keep updated, please stay tuned~**\n\n### Short 5 Min Video \n[![Please Click the Figure](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/figures/architecture.png)](https://www.youtube.com/watch?v=FwS3X5xcj8A&list=LL&index=5)\n\n### Usage on Manipulation Localization (_e.g._, Columbia, Coverage, CASIA, NIST16 and IMD2020)\n- To create your environment by\n  ```\n  conda env create -f environment.yml\n  ```\n  or mannually install `pytorch 1.11.0` and `torchvision 0.12.0` in `python 3.7.16`.\n- Go to [localization_weights_link](https://drive.google.com/drive/folders/1cxCoE2hjcDj4lLrJmGEbskzPRJfoDIMJ?usp=sharing) to download the weights from, and then put them in `weights`.\n- To apply the pre-trained model on images in the `./data_dir` and then obtain results in `./viz_eval`, please run\n  ```\n  bash HiFi_Net_loc.sh\n  ```\n- More quantitative and qualitative results can be found at: [csv](https://drive.google.com/drive/folders/12iS0ILb6ndXtdWjonByrgnejzuAvwCqp?usp=sharing) and [qualitative results](https://drive.google.com/drive/folders/1iZp6ciOHSbGq4EsC_AYl7zVK24gBtrd1?usp=sharing).\n- If you would like to generate the above result. Download $5$ datasets via [link](https://drive.google.com/file/d/1RYXTg0Q82KEvkeOtaaR5AZ0FBx5219SY/view?usp=sharing) and unzip it by `tar -xvf data.tar.gz`. 
Then, uncomment this [line](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/utils/load_edata.py#L21) and run `HiFi_Net_loc.sh`. \n\n### Usage on Detecting and Localization for the general forged content including GAN and diffusion-generated images:\n- This reproduces detection and localization results in the HiFi-IFDL dataset (Tab. 2 and Supplementary Fig.1)\n- Go to [HiFi_IFDL_weights_link](https://drive.google.com/drive/folders/1v07aJ2hKmSmboceVwOhPvjebFMJFHyhm?usp=sharing) to download the weights, and then put them in `weights`. \n- The quick usage on HiFi_Net:\n```python\n  from HiFi_Net import HiFi_Net \n  from PIL import Image\n  import numpy as np\n\n  HiFi = HiFi_Net()   # initialize\n  img_path = 'asset/sample_1.jpg'\n\n  ## detection\n  res3, prob3 = HiFi.detect(img_path)\n  # print(res3, prob3) 1 1.0\n  HiFi.detect(img_path, verbose=True)\n\n  ## localization\n  binary_mask = HiFi.localize(img_path)\n  binary_mask = Image.fromarray((binary_mask*255.).astype(np.uint8))\n  binary_mask.save('pred_mask.png')\n```\n\n### Quick Start of Source Code\nA quick view of the code structure:\n```bash\n./HiFi_IFDL\n    ├── HiFi_Net_loc.py (localization files)\n    ├── HiFi_Net_loc.sh (localization evaluation)\n    ├── HiFi_Net.py (API for the user input image.)\n    ├── IMD_dataloader.py (call dataloaders in the utils folder)\n    ├── model (model module folder)\n    │      ├── NLCDetection_pconv.py (partial convolution, localization, and classification modules)\n    │      ├── seg_hrnet.py (feature extractor based on HRNet)\n    │      ├── LaPlacianMs.py (laplacian filter on the feature map)\n    │      ├── GaussianSmoothing.py (self-made smoothing functions)\n    │      └── ...   
\n    ├── utils (utils, dataloader, and localization loss class.)\n    │      ├── custom_loss.py (localization loss class and the real pixel center initialization)\n    │      ├── utils.py\n    │      ├── load_data.py (loading training and val dataset.)\n    │      └── load_edata.py (loading inference dataset.)\n    ├── asset (folder contains sample images with their ground truth and predictions.)\n    ├── weights (put the pre-trained weights in.)\n    ├── center (The pre-computed `.pth` file for the HiFi-IFDL dataset.)\n    └── center_loc (The pre-computed `.pth` file for the localization task (Tab.3 in the paper).)\n```\n\n### Question and Answers.\nQ1. Why train and val datasets are in the same path? \n\nA1. For each forgery method, we save both train and val in the SAME folder, from which we use a text file to obtain the training and val images. The text file contains a list of image names, and the first `val_num` are used for training and the last \"val_num\" for validation. Specifically, refer to [code](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/utils/load_data.py#L271) for details. What is more, we build up the code on the top of the PSCC-Net, which adapts the same style of loading data, please compare [code1](https://github.com/proteus1991/PSCC-Net/blob/main/utils/load_tdata.py#L88) with [code2](https://github.com/proteus1991/PSCC-Net/blob/main/utils/load_tdata.py#L290).\n\nQ2. What is the dataset naming for STGAN and the face-shifter section?\n\nA2. Please check the STGAN.txt in this [link](https://drive.google.com/drive/folders/1OIUv7OGxfAyerMnmKvrNnN_5CmIDcNxo?usp=sharing), which contains all manipulated/modified images we have used for training and validation. This txt file will be loaded by this line of [code](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/utils/load_data.py#L163), which says about the corresponding masks. 
Lastly, I am not sure whether I have released the authentic images; if I have not, you can simply find them in the public CelebA-HQ dataset. I will try to provide a rigid naming scheme for the dataset in the near future. \n\n### Reference\nIf you would like to use our work, please cite:\n```Bibtex\n@inproceedings{hifi_net_xiaoguo,\n  author = { Xiao Guo and Xiaohong Liu and Zhiyuan Ren and Steven Grosz and Iacopo Masi and Xiaoming Liu },\n  title = { Hierarchical Fine-Grained Image Forgery Detection and Localization },\n  booktitle = { CVPR },\n  year = { 2023 },\n}\n```\n"
  },
  {
    "path": "applications/CNNImage_detection/README.md",
    "content": ""
  },
  {
    "path": "applications/DiffVideo_detection/README.md",
    "content": ""
  },
  {
    "path": "applications/deepfake_detection/FF++/put_weight_here",
    "content": ""
  },
  {
    "path": "applications/deepfake_detection/README.md",
    "content": "# HiFi_Deepfake\r\n\r\nWe apply the HiFi_Net for the deepfake detection as the following diagram:\r\n\r\n<p align=\"center\">\r\n  <img src=\"https://github.com/CHELSEA234/HiFi_IFDL/blob/main/figures/HiFi_deepfake.png\" alt=\"drawing\" width=\"1000\"/>\r\n</p>\r\n\r\n### Reported Performance\r\n<center>\r\n  \r\n| Dataset | AUC | Accuracy | EER | TPR@FPR=**$10$**% |TPR@FPR=**$1$**% | \r\n|:----:|:----:|:----:|:----:|:----:|:----:|\r\n|FF++(c40)|$92.10$|$89.16$|N/A|$74.44$|$40.85$\r\n|CelebDF|$68.80$|$67.20$|$36.13$|N/A|N/A\r\n|WildDeepfake|$65.22$|$66.29$|$38.65$|N/A|N/A\r\n  \r\n</center>\r\n\r\nMore results please refer to the table $3$ of our ECCV2024 paper [[ArXiv]](https://arxiv.org/pdf/2402.00126)\r\n\r\n### The Pre-trained Weights and User-friendly Preprocessed Dataset:\r\n1. The pre-trained weights on FF++ can be download via [[link]](https://drive.google.com/drive/folders/1AElYlVxsahgGIua3m3Kj2VhSc3S7ADLJ?usp=sharing)\r\n2. We offer a preprocessed FF++ dataset in the HDF5 file format [[link]](https://drive.google.com/drive/folders/1ovuurFCkBfmcMq7HKO5ph36U1QyL75UA?usp=sharing), supporting faster I/O. The dataset follows the naming ```FF++_{manipulation_type}_{compression rate}.h5``` and is structured as follows:\r\n```\r\nFF++_Deepfakes_c23.h5:\r\nFF++_Deepfakes_c40.h5\r\nFF++_Face2Face_c23.h5\r\nFF++_Face2Face_c40.h5\r\n...\r\n```\r\n\r\n### Quick Start\r\n1. Setup the environment using ```environment.yml```, then put the pre-trained weights in ```FF++``` folder.\r\n2. Download the entire dataset or a small portion of datasets, for example ```FF++_original_c40.h5``` and ```FF++_Deepfakes_c40.h5```.\r\n3. Run `bash test.sh` after setting up the data path [here](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/applications/deepfake_detection/test.py#L106).\r\n4. 
If you choose to run the small portion dataset (e.g., ```FF++_original_c40.h5``` and ```FF++_Deepfakes_c40.h5```), please comment out this [line](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/applications/deepfake_detection/test.py#L34)\r\n\r\n### Quick View of Code\r\n```bash\r\n./deepfake_detection\r\n    ├── test.py (the inference code.)\r\n    ├── test.sh (run the inference code.)\r\n    ├── dataset_test.py (dataset tutorial)\r\n    ├── dataset_test.sh (dataset tutorial)\r\n    ├── train.py (the train code.)\r\n    ├── train.sh (run the train code.)\r\n    ├── exp_FF_c40_bs_32_lr_0.0001_ws_10.txt (The training log file.)\r\n    ├── FF++ (Please download the pre-trained weights and put it here)\r\n    ├── sequence (model module folder)\r\n    │      ├── rnn_stratified_dataloader.py (dataloader)\r\n    │      ├── runjobs_utils.py (the first utility)\r\n    │      ├── torch_utils.py (the second utility)\r\n    │      └── models\r\n    │            ├── run_model.sh (model tutorial)\r\n    │            ├── LaPlacianMs.py\r\n    │            ├── HiFiNet_deepfake.py\r\n    │            └── ...\r\n    └── environment.yml\r\n```\r\n"
  },
  {
    "path": "applications/deepfake_detection/dataset_test.py",
    "content": "# coding: utf-8\n# author: Hierarchical Fine-Grained Image Forgery Detection and Localization\nimport os\nimport numpy as np\nimport subprocess\nimport logging\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport argparse\nimport datetime\n\nfrom tensorboardX import SummaryWriter\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nsource_path = os.path.join('./sequence')\nsys.path.append(source_path)\nfrom rnn_stratified_dataloader import get_dataloader\nfrom models.HiFiNet_deepfake import HiFiNet_deepfake\nfrom torch_utils import eval_model,display_eval_tb,train_logging,lrSched_monitor\nfrom runjobs_utils import init_logger,Saver,DataConfig,torch_load_model\n\nlogger = init_logger(__name__)\nlogger.setLevel(logging.INFO)\n\nstarting_time = datetime.datetime.now()\n\n## Deterministic training\n_seed_id = 100\ntorch.backends.cudnn.deterministic = True\ntorch.manual_seed(_seed_id)\n\ndatasets = ['original', 'Deepfakes', 'FaceSwap', 'NeuralTextures', 'Face2Face']\n# datasets = ['original', 'Deepfakes']\nmanipulations_names = [n for c, n in enumerate(datasets) if n != 'original']\nmanipulations_dict = {n : c  for c, n in enumerate(manipulations_names) }\nmanipulations_dict['original'] = 255\n\nfor key, value in manipulations_dict.items():\n\tprint(key, value)\nctype = 'c40'\n\n# Create the parser\nparser = argparse.ArgumentParser(description='Process some integers.')\nparser.add_argument('--batch_size', type=int, default=4, help='input batch size for training (default: 32)')\nparser.add_argument('--window_size', type=int, default=5, help='size of the sliding window (default: 5)')\nparser.add_argument('--dataset_name', type=str, default=\"FF++\", help='size of the sliding window (default: 5)')\nparser.add_argument('--gpus', type=int, default=4, help='input batch size for training (default: 32)')\nparser.add_argument('--feat_dim', type=int, default=270, help='input dim to rnn. 
(default: 32)')\nparser.add_argument('--valid_epoch', type=int, default=2, help='val epoch')\nparser.add_argument('--display_step', type=int, default=50, help='display the loss value.')\nparser.add_argument('--learning_rate', type=float, default=1e-3, help='the used learning rate')\n\n# Parse the arguments\nargs = parser.parse_args()\n## Hyper-params #######################\nhparams = {\n            'epochs': 50, 'batch_size': args.batch_size, \n            'basic_lr': args.learning_rate, 'fine_tune': True, 'use_laplacian': True, \n            'step_factor': 0.1, 'patience': 20, 'weight_decay': 1e-06, 'lr_gamma': 2.0, 'use_magic_loss': True, \n            'feat_dim': args.feat_dim, 'drop_rate': 0.2, \n            'skip_valid': False, 'rnn_type': 'LSTM', 'rnn_hidden_size': 256, \n            'num_rnn_layers': 1, 'rnn_drop_rate': 0.2, \n            'bidir': False, 'merge_mode': 'concat', 'perc_margin_1': 0.95, 'perc_margin_2': 0.95, 'soft_boundary': False, \n            'dist_p': 2, 'radius_param': 0.84, 'strat_sampling': True, 'normalize': True, 'window_size': args.window_size, 'hop': 1, \n            'valid_epoch': args.valid_epoch, 'display_step': args.display_step, 'use_sched_monitor': True\n            }\nbatch_size = hparams['batch_size']\nbasic_lr = hparams['basic_lr']\nfine_tune = hparams['fine_tune']\nuse_laplacian = hparams['use_laplacian']\nstep_factor = hparams['step_factor']\npatience = hparams['patience']\nweight_decay = hparams['weight_decay']\nlr_gamma = hparams['lr_gamma']\nuse_magic_loss = hparams['use_magic_loss']\nfeat_dim = hparams['feat_dim']\ndrop_rate = hparams['drop_rate']\nrnn_type = hparams['rnn_type']\nrnn_hidden_size = hparams['rnn_hidden_size']\nnum_rnn_layers = hparams['num_rnn_layers']\nrnn_drop_rate = hparams['rnn_drop_rate']\nbidir = hparams['bidir']\nmerge_mode = hparams['merge_mode']\nperc_margin_1 = hparams['perc_margin_1']\nperc_margin_2 = hparams['perc_margin_2']\ndist_p = hparams['dist_p']\nradius_param = 
hparams['radius_param']\nstrat_sampling = hparams['strat_sampling']\nnormalize = hparams['normalize']\nwindow_size = hparams['window_size']\nhop = hparams['hop']\nsoft_boundary = hparams['soft_boundary']\nuse_sched_monitor = hparams['use_sched_monitor']\n########################################\nworkers_per_gpu = 6\ndataset_name = f\"{args.dataset_name}\"\nexp_name = f\"05_exp_c40_bs_{batch_size}_lr_{basic_lr}_ws_{window_size}\"\nmodel_name = exp_name\nmodel_path = os.path.join(f'./{dataset_name}', model_name)\nprint(f'Window_size: {args.window_size}; Dataset: {dataset_name}; Batch_Size: {batch_size}; LR: {basic_lr}.')\n\n# Create the model path if doesn't exists\nif not os.path.exists(model_path):\n    subprocess.call(f\"mkdir -p {model_path}\", shell=True)\n\n## Data Generation\nimg_path = \"/user/guoxia11/cvlshare/cvl-guoxia11/FaceForensics_HiFiNet\"\nbalanced_minibatch_opt = True\n\nif dataset_name == 'FF++':\n    train_generator, train_dataset = get_dataloader(\n                                                img_path, datasets, ctype, manipulations_dict, window_size, hop, \n                                                use_laplacian, normalize, strat_sampling, balanced_minibatch_opt, \n                                                'train', batch_size, workers=workers_per_gpu*args.gpus\n                                                )\n    test_generator, test_dataset = get_dataloader(\n                                                img_path, datasets, ctype, manipulations_dict, window_size, hop, \n                                                use_laplacian, normalize, strat_sampling, False, \n                                                'test', batch_size, workers=workers_per_gpu*args.gpus\n                                                )\n    # print(\"the dataset length is: \", len(train_dataset))\n    print(\"the dataloader length is: \", len(train_generator))\n    # del train_dataset\n    # del test_dataset\nelif dataset_name == \"CelebDF\":   
     \n    pass    ## TODO: will be released in the near future. \nelif dataset_name == 'DFW':\n    pass    ## TODO: will be released in the near future. \n\nprint('train: ', len(train_generator), len(train_dataset))\nprint('test: ', len(test_generator), len(test_dataset))\nfor ib, (img_batch_mmodal, true_labels, manip_type) in enumerate(train_generator,1):\n      print(img_batch_mmodal.size(), true_labels.size(), manip_type[:2])\n      if ib == 1:\n            break\nfor ib, (img_batch_mmodal, true_labels, manip_type) in enumerate(test_generator,1):\n      print(ib, img_batch_mmodal.size(), true_labels.size(), manip_type[:2])\n      if ib == 1:\n            break\nprint(\"...over...\")"
  },
  {
    "path": "applications/deepfake_detection/dataset_test.sh",
    "content": "source ~/.bashrc\nconda activate HiFi_Net_deepfake\nCUDA_NUM=7\nCUDA_VISIBLE_DEVICES=$CUDA_NUM python dataset_test.py \\\n                                    --dataset_name FF++ \\\n                                    --batch_size 32 \\\n                                    --window_size 10 \\\n                                    --gpus 1 \\\n                                    --valid_epoch 1 \\\n                                    --feat_dim 1000 \\\n                                    --learning_rate 1e-4 \\\n                                    --display_step 100"
  },
  {
    "path": "applications/deepfake_detection/environment.yml",
    "content": "name: HiFi_Net_deepfake\nchannels:\n  - pytorch\n  - conda-forge\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_gnu\n  - absl-py=1.3.0=py37h06a4308_0\n  - aiohttp=3.8.3=py37h5eee18b_0\n  - aiosignal=1.2.0=pyhd3eb1b0_0\n  - async-timeout=4.0.2=py37h06a4308_0\n  - asynctest=0.13.0=py_0\n  - attrs=22.1.0=py37h06a4308_0\n  - blas=1.0=mkl\n  - blinker=1.4=py37h06a4308_0\n  - brotlipy=0.7.0=py37h27cfd23_1003\n  - bzip2=1.0.8=h7b6447c_0\n  - c-ares=1.19.1=h5eee18b_0\n  - ca-certificates=2023.12.12=h06a4308_0\n  - cachetools=4.2.2=pyhd3eb1b0_0\n  - certifi=2022.12.7=py37h06a4308_0\n  - cffi=1.15.1=py37h5eee18b_3\n  - charset-normalizer=2.0.4=pyhd3eb1b0_0\n  - click=8.0.4=py37h06a4308_0\n  - cryptography=39.0.1=py37h9ce1e76_0\n  - cudatoolkit=11.3.1=h2bc3f7f_2\n  - cycler=0.11.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=hf484d3e_0\n  - fftw=3.3.9=h27cfd23_1\n  - freetype=2.12.1=h4a9f257_0\n  - frozenlist=1.3.3=py37h5eee18b_0\n  - giflib=5.2.1=h5eee18b_3\n  - gmp=6.2.1=h295c915_3\n  - gnutls=3.6.15=he1e5248_0\n  - google-auth=2.6.0=pyhd3eb1b0_0\n  - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0\n  - grpcio=1.42.0=py37hce63b2e_0\n  - icu=67.1=he1b5a44_0\n  - idna=3.4=py37h06a4308_0\n  - imageio=2.9.0=pyhd3eb1b0_0\n  - importlib-metadata=4.11.3=py37h06a4308_0\n  - intel-openmp=2021.4.0=h06a4308_3561\n  - joblib=1.1.0=pyhd3eb1b0_0\n  - jpeg=9e=h5eee18b_1\n  - kiwisolver=1.4.4=py37h6a678d5_0\n  - lame=3.100=h7b6447c_0\n  - lcms2=2.12=h3be6417_0\n  - ld_impl_linux-64=2.38=h1181459_1\n  - lerc=3.0=h295c915_0\n  - libblas=3.9.0=12_linux64_mkl\n  - libcblas=3.9.0=12_linux64_mkl\n  - libdeflate=1.17=h5eee18b_1\n  - libffi=3.4.4=h6a678d5_0\n  - libgcc-ng=11.2.0=h1234567_1\n  - libgfortran-ng=11.2.0=h00389a5_1\n  - libgfortran5=11.2.0=h1234567_1\n  - libgomp=11.2.0=h1234567_1\n  - libiconv=1.16=h7f8727e_2\n  - libidn2=2.3.4=h5eee18b_0\n  - libpng=1.6.39=h5eee18b_0\n  - libprotobuf=3.20.3=he621ea3_0\n  - libstdcxx-ng=11.2.0=h1234567_1\n  - 
libtasn1=4.19.0=h5eee18b_0\n  - libtiff=4.5.1=h6a678d5_0\n  - libunistring=0.9.10=h27cfd23_0\n  - libuv=1.44.2=h5eee18b_0\n  - libwebp=1.2.4=h11a3e52_1\n  - libwebp-base=1.2.4=h5eee18b_1\n  - lz4-c=1.9.4=h6a678d5_0\n  - markdown=3.4.1=py37h06a4308_0\n  - markupsafe=2.1.1=py37h7f8727e_0\n  - matplotlib=3.2.2=1\n  - matplotlib-base=3.2.2=py37h1d35a4c_1\n  - mkl=2021.4.0=h06a4308_640\n  - mkl-service=2.4.0=py37h7f8727e_0\n  - mkl_fft=1.3.1=py37hd3c417c_0\n  - mkl_random=1.2.2=py37h51133e4_0\n  - multidict=6.0.2=py37h5eee18b_0\n  - ncurses=6.4=h6a678d5_0\n  - nettle=3.7.3=hbbd107a_1\n  - numpy=1.21.5=py37h6c91a56_3\n  - numpy-base=1.21.5=py37ha15fc14_3\n  - oauthlib=3.2.1=py37h06a4308_0\n  - openh264=2.1.1=h4ff587b_0\n  - openssl=1.1.1w=h7f8727e_0\n  - pillow=9.4.0=py37h6a678d5_0\n  - pip=23.3.2=pyhd8ed1ab_0\n  - protobuf=3.20.3=py37h6a678d5_0\n  - pyasn1=0.4.8=pyhd3eb1b0_0\n  - pyasn1-modules=0.2.8=py_0\n  - pycparser=2.21=pyhd3eb1b0_0\n  - pyjwt=2.4.0=py37h06a4308_0\n  - pyopenssl=23.0.0=py37h06a4308_0\n  - pyparsing=3.0.9=py37h06a4308_0\n  - pysocks=1.7.1=py37_1\n  - python=3.7.16=h7a1cb2a_0\n  - python-dateutil=2.8.2=pyhd3eb1b0_0\n  - python_abi=3.7=2_cp37m\n  - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0\n  - pytorch-mutex=1.0=cuda\n  - pyyaml=6.0=py37h5eee18b_1\n  - readline=8.2=h5eee18b_0\n  - requests=2.28.1=py37h06a4308_0\n  - requests-oauthlib=1.3.0=py_0\n  - rsa=4.7.2=pyhd3eb1b0_1\n  - scikit-learn=1.0.2=py37hf9e9bfc_0\n  - scipy=1.7.3=py37h6c91a56_2\n  - setuptools=68.2.2=pyhd8ed1ab_0\n  - six=1.16.0=pyhd3eb1b0_1\n  - sqlite=3.41.2=h5eee18b_0\n  - tensorboard=2.10.0=py37h06a4308_0\n  - tensorboard-data-server=0.6.1=py37h52d8a92_0\n  - tensorboard-plugin-wit=1.8.1=py37h06a4308_0\n  - threadpoolctl=2.2.0=pyh0d69192_0\n  - tk=8.6.12=h1ccaba5_0\n  - torchvision=0.12.0=py37_cu113\n  - tornado=5.1.1=py37h7b6447c_0\n  - tqdm=4.64.1=py37h06a4308_0\n  - typing-extensions=4.3.0=py37h06a4308_0\n  - typing_extensions=4.3.0=py37h06a4308_0\n  - 
urllib3=1.26.14=py37h06a4308_0\n  - werkzeug=2.2.2=py37h06a4308_0\n  - wheel=0.38.4=py37h06a4308_0\n  - xz=5.4.5=h5eee18b_0\n  - yacs=0.1.6=pyhd3eb1b0_1\n  - yaml=0.2.5=h7b6447c_0\n  - yarl=1.8.1=py37h5eee18b_0\n  - zipp=3.11.0=py37h06a4308_0\n  - zlib=1.2.13=h5eee18b_0\n  - zstd=1.5.5=hc292b87_0\n  - pip:\n    - einops==0.6.1\n    - h5py==3.8.0\n    - kmeans-pytorch==0.3\n    - opencv-python==4.8.1.78\n    - packaging==24.0\n    - tensorboardx==2.6.2.2\n\n"
  },
  {
    "path": "applications/deepfake_detection/exp_FF_c40_bs_32_lr_0.0001_ws_10.txt",
    "content": "AUC: 0.8829070609725371\nBest Accuracy: 0.8590476190476191 (Threshold: 0.46431525609451324)\nTPR at FPR=10.0%: 0.6581349206349206 (Score: 0.9032643437385559)\nTPR at FPR=1.0%: 0.33174603174603173 (Score: 0.9792982339859009)\nAverage Loss: 0.3208117520030077\n####################################################################################################AUC: 0.8959469482237339\nBest Accuracy: 0.8698412698412699 (Threshold: 0.4833123738989796)\nTPR at FPR=10.0%: 0.7043650793650794 (Score: 0.9370318651199341)\nTPR at FPR=1.0%: 0.35694444444444445 (Score: 0.9900425672531128)\nAverage Loss: 0.3183356274089535\n####################################################################################################AUC: 0.8979908352229781\nBest Accuracy: 0.8706349206349207 (Threshold: 0.057323044208089015)\nTPR at FPR=10.0%: 0.709920634920635 (Score: 0.7713479399681091)\nTPR at FPR=1.0%: 0.3773809523809524 (Score: 0.9895150661468506)\nAverage Loss: 0.44491299368715304\n####################################################################################################AUC: 0.9030002047115142\nBest Accuracy: 0.8752380952380953 (Threshold: 0.1729503843006701)\nTPR at FPR=10.0%: 0.7001984126984127 (Score: 0.9017165899276733)\nTPR at FPR=1.0%: 0.4263888888888889 (Score: 0.9916896820068359)\nAverage Loss: 0.3399755131538371\n####################################################################################################AUC: 0.8975945609725371\nBest Accuracy: 0.8757142857142857 (Threshold: 0.2071308652196435)\nTPR at FPR=10.0%: 0.6819444444444445 (Score: 0.9187996983528137)\nTPR at FPR=1.0%: 0.3384920634920635 (Score: 0.9935241937637329)\nAverage Loss: 0.3518789753537663\n####################################################################################################AUC: 0.8932506613756613\nBest Accuracy: 0.8668253968253968 (Threshold: 0.031224684557531163)\nTPR at FPR=10.0%: 0.6708333333333333 (Score: 0.5181991457939148)\nTPR at FPR=1.0%: 
0.3732142857142857 (Score: 0.973039448261261)\nAverage Loss: 0.7130112742277117\n####################################################################################################AUC: 0.9064488063744018\nBest Accuracy: 0.8771428571428571 (Threshold: 0.41749477599181195)\nTPR at FPR=10.0%: 0.7198412698412698 (Score: 0.9881836771965027)\nTPR at FPR=1.0%: 0.35535714285714287 (Score: 0.9986361861228943)\nAverage Loss: 0.36584154823334025\n####################################################################################################AUC: 0.8815618701184177\nBest Accuracy: 0.8687301587301587 (Threshold: 0.472512484058953)\nTPR at FPR=10.0%: 0.629563492063492 (Score: 0.9977478384971619)\nTPR at FPR=1.0%: 0.2505952380952381 (Score: 0.9991620779037476)\nAverage Loss: 0.4665799030846784\n####################################################################################################AUC: 0.8969081475182665\nBest Accuracy: 0.8742857142857143 (Threshold: 0.3673135946369051)\nTPR at FPR=10.0%: 0.6833333333333333 (Score: 0.9759820699691772)\nTPR at FPR=1.0%: 0.2388888888888889 (Score: 0.9987502098083496)\nAverage Loss: 0.36425455615189173\n####################################################################################################AUC: 0.8948829207608969\nBest Accuracy: 0.8776190476190476 (Threshold: 0.6733377333917661)\nTPR at FPR=10.0%: 0.6797619047619048 (Score: 0.9928593635559082)\nTPR at FPR=1.0%: 0.21805555555555556 (Score: 0.9992702603340149)\nAverage Loss: 0.38597667283144527\n####################################################################################################AUC: 0.8982062547241119\nBest Accuracy: 0.8712698412698413 (Threshold: 0.015597607697074736)\nTPR at FPR=10.0%: 0.7041666666666667 (Score: 0.6901902556419373)\nTPR at FPR=1.0%: 0.36468253968253966 (Score: 0.9896582961082458)\nAverage Loss: 0.6102730501254047\n####################################################################################################AUC: 
0.9000834593096498\nBest Accuracy: 0.8765079365079365 (Threshold: 0.1510543103770636)\nTPR at FPR=10.0%: 0.7001984126984127 (Score: 0.9870874881744385)\nTPR at FPR=1.0%: 0.2859126984126984 (Score: 0.9993632435798645)\nAverage Loss: 0.41046855883994304\n####################################################################################################AUC: 0.8990011652809271\nBest Accuracy: 0.8803174603174603 (Threshold: 0.10573415606968273)\nTPR at FPR=10.0%: 0.6757936507936508 (Score: 0.9892104864120483)\nTPR at FPR=1.0%: 0.2892857142857143 (Score: 0.9993498921394348)\nAverage Loss: 0.41615531263343497\n####################################################################################################AUC: 0.9041319444444444\nBest Accuracy: 0.8761904761904762 (Threshold: 0.04045753150121847)\nTPR at FPR=10.0%: 0.7011904761904761 (Score: 0.9294343590736389)\nTPR at FPR=1.0%: 0.32063492063492066 (Score: 0.9988333582878113)\nAverage Loss: 0.44224551875087514\n####################################################################################################AUC: 0.8955598072562357\nBest Accuracy: 0.8823809523809524 (Threshold: 0.08566011418851711)\nTPR at FPR=10.0%: 0.6718253968253968 (Score: 0.9880918264389038)\nTPR at FPR=1.0%: 0.27996031746031746 (Score: 0.9995488524436951)\nAverage Loss: 0.46919996677150333\n####################################################################################################AUC: 0.9041175359032503\nBest Accuracy: 0.8798412698412699 (Threshold: 0.13584205501212096)\nTPR at FPR=10.0%: 0.7011904761904761 (Score: 0.9583638906478882)\nTPR at FPR=1.0%: 0.24841269841269842 (Score: 0.9993153810501099)\nAverage Loss: 0.40308997611149433\n####################################################################################################AUC: 0.8985135582010583\nBest Accuracy: 0.8792063492063492 (Threshold: 0.0554036657163878)\nTPR at FPR=10.0%: 0.6716269841269841 (Score: 0.9639698266983032)\nTPR at FPR=1.0%: 0.2152777777777778 (Score: 
0.9995352029800415)\nAverage Loss: 0.5122716583750035\n####################################################################################################AUC: 0.9058038863693626\nBest Accuracy: 0.8850793650793651 (Threshold: 0.13075093828225343)\nTPR at FPR=10.0%: 0.7218253968253968 (Score: 0.9681994915008545)\nTPR at FPR=1.0%: 0.2623015873015873 (Score: 0.999649167060852)\nAverage Loss: 0.43187202120434404\n####################################################################################################AUC: 0.8971601788863695\nBest Accuracy: 0.8804761904761905 (Threshold: 0.045349182414935775)\nTPR at FPR=10.0%: 0.6609126984126984 (Score: 0.9611561894416809)\nTPR at FPR=1.0%: 0.2396825396825397 (Score: 0.9993257522583008)\nAverage Loss: 0.5295804329411491\n####################################################################################################AUC: 0.9008590010078106\nBest Accuracy: 0.8763492063492063 (Threshold: 0.40709965036355367)\nTPR at FPR=10.0%: 0.7013888888888888 (Score: 0.9943603873252869)\nTPR at FPR=1.0%: 0.28115079365079365 (Score: 0.9997544884681702)\nAverage Loss: 0.43053449967426666\n####################################################################################################AUC: 0.9007028691106073\nBest Accuracy: 0.8792063492063492 (Threshold: 0.11574143740985)\nTPR at FPR=10.0%: 0.703968253968254 (Score: 0.9931934475898743)\nTPR at FPR=1.0%: 0.20615079365079364 (Score: 0.9997492432594299)\nAverage Loss: 0.45830285182233954\n####################################################################################################AUC: 0.8918712207105064\nBest Accuracy: 0.871904761904762 (Threshold: 0.1257552100813786)\nTPR at FPR=10.0%: 0.6880952380952381 (Score: 0.9957960844039917)\nTPR at FPR=1.0%: 0.25416666666666665 (Score: 0.9998204112052917)\nAverage Loss: 0.5295193400679405\n####################################################################################################AUC: 0.8913471592340638\nBest Accuracy: 
0.8784126984126984 (Threshold: 0.33674659807616425)\nTPR at FPR=10.0%: 0.6422619047619048 (Score: 0.9985455274581909)\nTPR at FPR=1.0%: 0.24285714285714285 (Score: 0.9997971653938293)\nAverage Loss: 0.47251886236014523\n####################################################################################################AUC: 0.9127528187200807\nBest Accuracy: 0.8819047619047619 (Threshold: 0.20754144801032828)\nTPR at FPR=10.0%: 0.7325396825396825 (Score: 0.9406241774559021)\nTPR at FPR=1.0%: 0.3759920634920635 (Score: 0.9955971837043762)\nAverage Loss: 0.32738428150521126\n####################################################################################################AUC: 0.9102273872511968\nBest Accuracy: 0.8811111111111111 (Threshold: 0.21728508113579234)\nTPR at FPR=10.0%: 0.7158730158730159 (Score: 0.9340832829475403)\nTPR at FPR=1.0%: 0.3998015873015873 (Score: 0.9958102703094482)\nAverage Loss: 0.3444241999582326\n####################################################################################################AUC: 0.9138169249181154\nBest Accuracy: 0.8822222222222222 (Threshold: 0.1270105143665013)\nTPR at FPR=10.0%: 0.7426587301587302 (Score: 0.9135438203811646)\nTPR at FPR=1.0%: 0.4218253968253968 (Score: 0.9957913160324097)\nAverage Loss: 0.36578153728431817\n####################################################################################################AUC: 0.9142961073318218\nBest Accuracy: 0.8849206349206349 (Threshold: 0.2570748374511923)\nTPR at FPR=10.0%: 0.7301587301587301 (Score: 0.9568803310394287)\nTPR at FPR=1.0%: 0.43353174603174605 (Score: 0.9967682361602783)\nAverage Loss: 0.36079730476172367\n####################################################################################################AUC: 0.9118335065507686\nBest Accuracy: 0.8817460317460317 (Threshold: 0.3071371658422034)\nTPR at FPR=10.0%: 0.7267857142857143 (Score: 0.9777267575263977)\nTPR at FPR=1.0%: 0.45575396825396824 (Score: 0.9973113536834717)\nAverage Loss: 
0.3544913220903623\n####################################################################################################AUC: 0.915317224111867\nBest Accuracy: 0.8812698412698413 (Threshold: 0.14671899654303475)\nTPR at FPR=10.0%: 0.7494047619047619 (Score: 0.9343666434288025)\nTPR at FPR=1.0%: 0.4027777777777778 (Score: 0.9977713823318481)\nAverage Loss: 0.3866839759710108\n####################################################################################################AUC: 0.9138940066767448\nBest Accuracy: 0.8811111111111111 (Threshold: 0.1215629355189854)\nTPR at FPR=10.0%: 0.7525793650793651 (Score: 0.9413536190986633)\nTPR at FPR=1.0%: 0.4005952380952381 (Score: 0.9978323578834534)\nAverage Loss: 0.3838400870690425\n####################################################################################################AUC: 0.9158619142101285\nBest Accuracy: 0.8819047619047619 (Threshold: 0.1816996639657616)\nTPR at FPR=10.0%: 0.7444444444444445 (Score: 0.9434211850166321)\nTPR at FPR=1.0%: 0.40793650793650793 (Score: 0.997951328754425)\nAverage Loss: 0.3850083784537087\n####################################################################################################AUC: 0.9120400289745527\nBest Accuracy: 0.8815873015873016 (Threshold: 0.2168000961569648)\nTPR at FPR=10.0%: 0.7331349206349206 (Score: 0.9838338494300842)\nTPR at FPR=1.0%: 0.37936507936507935 (Score: 0.998543381690979)\nAverage Loss: 0.36806364380399964\n####################################################################################################AUC: 0.9090080939783322\nBest Accuracy: 0.88 (Threshold: 0.08131052524169635)\nTPR at FPR=10.0%: 0.7238095238095238 (Score: 0.9731644988059998)\nTPR at FPR=1.0%: 0.35138888888888886 (Score: 0.9985748529434204)\nAverage Loss: 0.41489630684949136\n####################################################################################################AUC: 0.9134412005542958\nBest Accuracy: 0.883968253968254 (Threshold: 0.17156062097213787)\nTPR at 
FPR=10.0%: 0.7331349206349206 (Score: 0.9726763367652893)\nTPR at FPR=1.0%: 0.38551587301587303 (Score: 0.9986814856529236)\nAverage Loss: 0.39861634454801015\n####################################################################################################AUC: 0.9126887282690853\nBest Accuracy: 0.8826984126984126 (Threshold: 0.3019230767371409)\nTPR at FPR=10.0%: 0.7279761904761904 (Score: 0.9891262054443359)\nTPR at FPR=1.0%: 0.3601190476190476 (Score: 0.9989126920700073)\nAverage Loss: 0.3922838481644826\n####################################################################################################AUC: 0.9078669847568657\nBest Accuracy: 0.8807936507936508 (Threshold: 0.30185993131420963)\nTPR at FPR=10.0%: 0.7170634920634921 (Score: 0.9934865236282349)\nTPR at FPR=1.0%: 0.3327380952380952 (Score: 0.9990170001983643)\nAverage Loss: 0.416769449960152\n####################################################################################################AUC: 0.9052425831443689\nBest Accuracy: 0.8809523809523809 (Threshold: 0.3770884994309789)\nTPR at FPR=10.0%: 0.703968253968254 (Score: 0.9969133138656616)\nTPR at FPR=1.0%: 0.3051587301587302 (Score: 0.9991865754127502)\nAverage Loss: 0.4331670764465533\n####################################################################################################AUC: 0.9116062767699672\nBest Accuracy: 0.8807936507936508 (Threshold: 0.11609432500454799)\nTPR at FPR=10.0%: 0.7152777777777778 (Score: 0.9900914430618286)\nTPR at FPR=1.0%: 0.3503968253968254 (Score: 0.9990679621696472)\nAverage Loss: 0.4059253654547324\n####################################################################################################AUC: 0.909878511589821\nBest Accuracy: 0.8838095238095238 (Threshold: 0.10603247603914004)\nTPR at FPR=10.0%: 0.7170634920634921 (Score: 0.986299455165863)\nTPR at FPR=1.0%: 0.3521825396825397 (Score: 0.9991635084152222)\nAverage Loss: 
0.47194165713981096\n####################################################################################################AUC: 0.9116446208112876\nBest Accuracy: 0.8836507936507937 (Threshold: 0.25656318423296115)\nTPR at FPR=10.0%: 0.7313492063492063 (Score: 0.9949676394462585)\nTPR at FPR=1.0%: 0.37817460317460316 (Score: 0.9992499947547913)\nAverage Loss: 0.4292217055991585\n####################################################################################################AUC: 0.9078786375661376\nBest Accuracy: 0.8836507936507937 (Threshold: 0.1410691703695467)\nTPR at FPR=10.0%: 0.7085317460317461 (Score: 0.993381142616272)\nTPR at FPR=1.0%: 0.30813492063492065 (Score: 0.999306321144104)\nAverage Loss: 0.4514668861100909\n####################################################################################################AUC: 0.9030014644746788\nBest Accuracy: 0.8812698412698413 (Threshold: 0.32175635900079325)\nTPR at FPR=10.0%: 0.683531746031746 (Score: 0.9981260895729065)\nTPR at FPR=1.0%: 0.298015873015873 (Score: 0.9993946552276611)\nAverage Loss: 0.47036494121964795\n####################################################################################################AUC: 0.9079246189216428\nBest Accuracy: 0.8819047619047619 (Threshold: 0.2916260416192136)\nTPR at FPR=10.0%: 0.7218253968253968 (Score: 0.9950783252716064)\nTPR at FPR=1.0%: 0.3113095238095238 (Score: 0.9993897676467896)\nAverage Loss: 0.4342052312062576\n####################################################################################################AUC: 0.9108625440917106\nBest Accuracy: 0.8853968253968254 (Threshold: 0.04057335027773752)\nTPR at FPR=10.0%: 0.7077380952380953 (Score: 0.9875094294548035)\nTPR at FPR=1.0%: 0.3238095238095238 (Score: 0.999352753162384)\nAverage Loss: 0.49118314668693375\n####################################################################################################AUC: 0.9123256802721089\nBest Accuracy: 0.88 (Threshold: 0.19828409675377698)\nTPR at 
FPR=10.0%: 0.7353174603174604 (Score: 0.8316733241081238)\nTPR at FPR=1.0%: 0.36507936507936506 (Score: 0.9861753582954407)\nAverage Loss: 0.3262811797141537\n####################################################################################################AUC: 0.9141698948097758\nBest Accuracy: 0.8820634920634921 (Threshold: 0.19807075297263402)\nTPR at FPR=10.0%: 0.7535714285714286 (Score: 0.8300660252571106)\nTPR at FPR=1.0%: 0.36884920634920637 (Score: 0.9891144037246704)\nAverage Loss: 0.32706290029395885\n####################################################################################################AUC: 0.9138121220710507\nBest Accuracy: 0.8823809523809524 (Threshold: 0.18788503558754022)\nTPR at FPR=10.0%: 0.7436507936507937 (Score: 0.8511702418327332)\nTPR at FPR=1.0%: 0.3825396825396825 (Score: 0.9907612204551697)\nAverage Loss: 0.32802260290183655\n####################################################################################################AUC: 0.914654903628118\nBest Accuracy: 0.883015873015873 (Threshold: 0.21272393984804353)\nTPR at FPR=10.0%: 0.7424603174603175 (Score: 0.8644278049468994)\nTPR at FPR=1.0%: 0.4033730158730159 (Score: 0.9910193085670471)\nAverage Loss: 0.3317383197679889\n####################################################################################################AUC: 0.9151813271604937\nBest Accuracy: 0.8817460317460317 (Threshold: 0.27262781311795303)\nTPR at FPR=10.0%: 0.7418650793650794 (Score: 0.8735067844390869)\nTPR at FPR=1.0%: 0.4005952380952381 (Score: 0.9921466708183289)\nAverage Loss: 0.32901159345366177\n####################################################################################################AUC: 0.9151310153691107\nBest Accuracy: 0.8819047619047619 (Threshold: 0.13756239538235432)\nTPR at FPR=10.0%: 0.7498015873015873 (Score: 0.8710891008377075)\nTPR at FPR=1.0%: 0.4158730158730159 (Score: 0.9922192096710205)\nAverage Loss: 
0.33308415401604047\n####################################################################################################"
  },
  {
    "path": "applications/deepfake_detection/sequence/models/GaussianSmoothing.py",
    "content": "# author: Hierarchical Fine-Grained Image Forgery Detection and Localization, CVPR2023\nimport os\nimport math\nimport numbers\nimport random\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass GaussianSmoothing(nn.Module):\n    \"\"\"\n    Apply gaussian smoothing on a\n    1d, 2d or 3d tensor. Filtering is performed seperately for each channel\n    in the input using a depthwise convolution.\n    Arguments:\n    channels (int, sequence): Number of channels of the input tensors. Output will\n    have this number of channels as well.\n    kernel_size (int, sequence): Size of the gaussian kernel.\n    sigma (float, sequence): Standard deviation of the gaussian kernel.\n    dim (int, optional): The number of dimensions of the data.\n    Default value is 2 (spatial).\n    \"\"\"\n    def __init__(self, channels, kernel_size, sigma, dim=2):\n        super(GaussianSmoothing, self).__init__()\n        if isinstance(kernel_size, numbers.Number):\n            kernel_size = [kernel_size] * dim\n        if isinstance(sigma, numbers.Number):\n            sigma = [sigma] * dim\n\n        # The gaussian kernel is the product of the\n        # gaussian function of each dimension.\n        kernel = 1\n        meshgrids = torch.meshgrid(\n            [\n                torch.arange(size, dtype=torch.float32)\n                for size in kernel_size\n            ], indexing='ij'\n        )\n        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):\n            mean = (size - 1) / 2\n            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \\\n                      torch.exp(-((mgrid - mean) / std) ** 2 / 2)\n\n        # Make sure sum of values in gaussian kernel equals 1.\n        kernel = kernel / torch.sum(kernel)\n\n        # Reshape to depthwise convolutional weight\n        kernel = kernel.view(1, 1, *kernel.size())\n        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))\n\n        
self.register_buffer('weight', kernel)\n        self.groups = channels\n\n        if dim == 1:\n            self.conv = F.conv1d\n        elif dim == 2:\n            self.conv = F.conv2d\n        elif dim == 3:\n            self.conv = F.conv3d\n        else:\n            raise RuntimeError(\n                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)\n            )\n\n    def forward(self, input):\n        \"\"\"\n        Apply gaussian filter to input.\n        Arguments:\n        input (torch.Tensor): Input to apply gaussian filter on.\n        Returns:\n        filtered (torch.Tensor): Filtered output.\n        \"\"\"\n        return self.conv(input, weight=self.weight, groups=self.groups)\n"
  },
  {
    "path": "applications/deepfake_detection/sequence/models/HiFiNet_deepfake.py",
    "content": "# coding: utf-8\n# author: Hierarchical Fine-Grained Image Forgery Detection and Localization\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport sys\nsys.path.append('./sequence/models')\nfrom hrnet.seg_hrnet_config import get_cfg_defaults\nfrom hrnet.seg_hrnet import get_seg_model\n\nclass Flatten(nn.Module):\n    def __init__(self):\n        super(Flatten, self).__init__()\n        \n    def forward(self, x):\t\t\n        return x.view(x.size(0), -1)\n\nclass CatDepth(nn.Module):\n    def __init__(self):\n        super(CatDepth, self).__init__()\n\n    def forward(self, x, y):\n        return torch.cat([x,y],dim=1)\n\nclass HiFiNet_deepfake(nn.Module):\n    def __init__(self, use_laplacian=False, drop_rate=0.5, use_magic_loss=True,\n                 feat_dim = 1024, pretrained=True,\n                 rnn_type='LSTM', rnn_hidden_size=10, num_rnn_layers=1, rnn_drop_rate=0.5,\n                 bidir=False, merge_mode='concat',gate_type='sigmoid', device='cuda'):\n        super(HiFiNet_deepfake, self).__init__()\n        self.use_laplacian = use_laplacian\n        self.feat_dim = feat_dim\n        self.rnn_type = rnn_type\n        self.rnn_input_size = feat_dim\n        self.rnn_hidden_size = rnn_hidden_size\n        self.num_rnn_layers = num_rnn_layers\n        self.rnn_drop_rate = rnn_drop_rate\n        self.bidir = bidir\n        self.magic_loss = use_magic_loss\n        \n        self.device = device\n        \n        self.FENet = get_seg_model(get_cfg_defaults()).to(self.device)\n        self.rnn = nn.LSTM(input_size=self.rnn_input_size, hidden_size=self.rnn_hidden_size,\n                            num_layers=self.num_rnn_layers, batch_first=False, dropout=self.rnn_drop_rate,\n                            bidirectional=self.bidir\n                            )\n        self.output_rnn = nn.Sequential(nn.ReLU(inplace=True),\n                                        nn.Linear(256, 2))\n\n        # Select the merger 
function\n        if merge_mode == 'concat':\n            self.merger_function = merge_concat\n        elif merge_mode == 'sum':\n            self.merger_function = merge_sum\n        \n    def forward(self,x):\n        batch_size, window_size, _, H, W = x.size()\n        x = x.view(batch_size * window_size, 3, H, W) # Input for RGB branch\n\n        conv_feat = self.FENet(x)\n        z = conv_feat.view(batch_size, window_size, -1).permute(1,0,2)\n        out, (h,c) = self.rnn(z)\n        out = self.merger_function(out[-1, :, :self.rnn_hidden_size], out[0, :, self.rnn_hidden_size:]) \n        out = self.output_rnn(out)\n\n        return out\n\n    def up (self,x, size):\n        return F.interpolate(x,size=size,mode='bilinear',\n                             align_corners=False)\n    def up_pix(self,x,r):\n        return F.pixel_shuffle(x,r)\n\n## Functions to merger the bidirectional outputs\n# Concatenation function\ndef merge_concat(out1, out2):\n    return torch.cat((out1, out2), 1)\n# Summation function\ndef merge_sum(out1, out2):\n    return torch.add(out1, out2)\n\nif __name__ == \"__main__\":\n    import torch\n    input = torch.randn((4, 1, 3, 224, 224)).cuda()  # [64, 10, 3, 224, 224]\n    model = HiFiNet_deepfake(use_laplacian=True, drop_rate=0.2, use_magic_loss=False,\n                            pretrained=True, rnn_drop_rate=0.2, feat_dim=1000,\n                            rnn_hidden_size=128, num_rnn_layers=2,\n                            bidir=True).cuda()\n    model = torch.nn.DataParallel(model)\n\n    print(f\"...comes to this place...\")\n    output = model(input)\n    print(f\"the model output: \", output.size())\n    print(\"...over...\")"
  },
  {
    "path": "applications/deepfake_detection/sequence/models/LaPlacianMs.py",
    "content": "# author: Hierarchical Fine-Grained Image Forgery Detection and Localization, CVPR2023\nimport os\nimport torch\nimport random\nimport numpy as np\nimport torch.nn as nn\nfrom torch.nn import functional as F\ntry:\n    from .GaussianSmoothing import GaussianSmoothing\nexcept:\n    from GaussianSmoothing import GaussianSmoothing\n\nclass LaPlacianMs(nn.Module):\n    def __init__(self,in_c,gauss_ker_size=3,scale=[2],drop_rate=0.2):\n        super(LaPlacianMs, self).__init__()\n        self.scale = scale\n        self.gauss_ker_size = gauss_ker_size\n        ## apply gaussian smoothing to input feature maps with 3 planes\n        ## with kernel size K and sigma s\n        self.smoothing = nn.ModuleDict()\n        for s in self.scale:\n            self.smoothing['scale-'+str(s)] = GaussianSmoothing(in_c, self.gauss_ker_size, s)\n\n        self.conv_1x1 = nn.Sequential(nn.Conv2d(in_c*len(scale), in_c,\n                                                kernel_size=1, stride=1,\n                                                bias=False,groups=1),\n                                                nn.BatchNorm2d(in_c),\n                                                nn.ReLU(inplace=True),\n                                                # nn.Dropout(p=drop_rate)\n                                            )\n        # Official init from torch repo.\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.constant_(m.bias, 0)\n\n    def down(self,x,s):\n        return F.interpolate(x,scale_factor=s,\n                             mode='bilinear',\n                             align_corners=False)\n    def up (self,x, size):\n        return 
F.interpolate(x,size=size,mode='bilinear',align_corners=False)\n\n    def forward(self, x):\n        for i, s in enumerate(self.scale):\n            sm = self.smoothing['scale-'+str(s)](x)\n            sm = self.up(self.down(sm,1/s),(x.shape[2],x.shape[3]))\n            if i == 0:\n                diff = x - sm\n            else:\n                diff = torch.cat((diff, x - sm), dim=1)\n        return self.conv_1x1(diff)\n"
  },
  {
    "path": "applications/deepfake_detection/sequence/models/hrnet/seg_hrnet.py",
    "content": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed under the MIT License.\n# The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn)\n# ------------------------------------------------------------------------------\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nfrom LaPlacianMs import LaPlacianMs\n\nimport os\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch._utils\nimport torch.nn.functional as F\n\nBN_MOMENTUM = 0.01\nlogger = logging.getLogger(__name__)\n\n# noise generation\ndef srm_generation(image):\n    \"\"\"\n    :param image: N * C * H * W\n    :return: noises\n    \"\"\"\n\n    # srm kernel 1\n    srm1 = np.zeros([5, 5]).astype('float32')\n    srm1[1:-1, 1:-1] = np.array([[-1, 2, -1],\n                                 [2, -4, 2],\n                                 [-1, 2, -1]])\n    srm1 /= 4.\n    # srm kernel 2\n    srm2 = np.array([[-1, 2, -2, 2, -1],\n                     [2, -6, 8, -6, 2],\n                     [-2, 8, -12, 8, -2],\n                     [2, -6, 8, -6, 2],\n                     [-1, 2, -2, 2, -1]]).astype('float32')\n    srm2 /= 12.\n    # srm kernel 3\n    srm3 = np.zeros([5, 5]).astype('float32')\n    srm3[2, 1:-1] = np.array([1, -2, 1])\n    srm3 /= 2.\n\n    srm = np.stack([srm1, srm2, srm3], axis=0)\n\n    W_srm = np.zeros([3, 3, 5, 5]).astype('float32')\n\n    for i in range(3):\n        W_srm[i, 0, :, :] = srm[i, :, :]\n        W_srm[i, 1, :, :] = srm[i, :, :]\n        W_srm[i, 2, :, :] = srm[i, :, :]\n\n    W_srm = torch.from_numpy(W_srm).to(image.get_device())\n\n    srm_noise = F.conv2d(image, W_srm, padding=2)\n\n    return srm_noise\n\n# bayar constrained layer\nclass BayarConstraint(object):\n    def __init__(self):\n        pass\n\n    def __call__(self, module):\n        if hasattr(module, 'weight'):\n            weight = 
module.weight.data      # oc, ic, h, w\n\n            h, w = weight.size()[2:]\n            mask = torch.zeros_like(weight)\n            mask[:, :, h//2, w//2] = 1\n\n            weight *= (1 - mask)\n            rest_sum = torch.sum(weight, dim=(2, 3), keepdim=True)\n            weight /= (rest_sum + 1e-7)\n            weight -= mask\n            module.weight.data = weight\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\nclass CatDepth(nn.Module):\n    def __init__(self):\n        super(CatDepth, self).__init__()\n\n    def forward(self, x, y):\n        return torch.cat([x,y],dim=1)\n\ndef weights_init(init_type='gaussian'):\n    def init_fun(m):\n        classname = m.__class__.__name__\n        if (classname.find('Conv') == 0 or classname.find(\n                'Linear') == 0) and hasattr(m, 'weight'):\n            if init_type == 'gaussian':\n                nn.init.normal_(m.weight, 0.0, 0.02)\n            elif init_type == 'xavier':\n                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))\n            elif init_type == 'kaiming':\n                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))\n            elif init_type == 'default':\n                pass\n            else:\n                assert 0, \"Unsupported initialization: {}\".format(init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                nn.init.constant_(m.bias, 0.0)\n    return init_fun\n\n'''GX: basicblock contains two conv3x3 and two batch norm'''\n'''GX: at last, it has a residual connection'''\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        
self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out = out + residual\n        out = self.relu(out)\n\n        return out\n\n'''GX: 3 conv + 3 bn then a residual.'''\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,\n                               bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion,\n                                  momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n      
  out = out + residual\n        out = self.relu(out)\n\n        return out\n\n'''GX: the basic component in the network.'''\nclass HighResolutionModule(nn.Module):\n    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,\n                 num_channels, fuse_method, multi_scale_output=True):\n        super(HighResolutionModule, self).__init__()\n        self._check_branches(\n            num_branches, blocks, num_blocks, num_inchannels, num_channels)\n\n        self.num_inchannels = num_inchannels\n        self.fuse_method = fuse_method\n        self.num_branches = num_branches\n\n        self.multi_scale_output = multi_scale_output\n\n        self.branches = self._make_branches(\n            num_branches, blocks, num_blocks, num_channels)\n        self.fuse_layers = self._make_fuse_layers()\n        self.relu = nn.ReLU(inplace=False)\n\n    def _check_branches(self, num_branches, blocks, num_blocks,\n                        num_inchannels, num_channels):\n        if num_branches != len(num_blocks):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(\n                num_branches, len(num_blocks))\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_channels):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(\n                num_branches, len(num_channels))\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_inchannels):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(\n                num_branches, len(num_inchannels))\n            raise ValueError(error_msg)\n\n    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,\n                         stride=1):\n        downsample = None\n        if stride != 1 or \\\n                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.num_inchannels[branch_index],\n 
                         num_channels[branch_index] * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,\n                               momentum=BN_MOMENTUM),\n            )\n\n        layers = []\n        layers.append(block(self.num_inchannels[branch_index],\n                            num_channels[branch_index], stride, downsample))\n        self.num_inchannels[branch_index] = \\\n            num_channels[branch_index] * block.expansion\n        for i in range(1, num_blocks[branch_index]):\n            layers.append(block(self.num_inchannels[branch_index],\n                                num_channels[branch_index]))\n\n        return nn.Sequential(*layers)\n\n    def _make_branches(self, num_branches, block, num_blocks, num_channels):\n        branches = []\n\n        for i in range(num_branches):\n            branches.append(\n                self._make_one_branch(i, block, num_blocks, num_channels))\n\n        return nn.ModuleList(branches)\n\n    ## GX: fuse layer converts feature maps at different resolution branches\n    ## GX: into the feature map of the new branches' feature map.\n    ## GX: https://zhuanlan.zhihu.com/p/335333233\n    def _make_fuse_layers(self):\n        if self.num_branches == 1:\n            return None\n\n        num_branches = self.num_branches\n        num_inchannels = self.num_inchannels\n        fuse_layers = []\n        for i in range(num_branches if self.multi_scale_output else 1):\n            fuse_layer = []\n            for j in range(num_branches):\n                if j > i:\n                    fuse_layer.append(nn.Sequential(\n                        nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),\n                        nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))\n                elif j == i:\n                    fuse_layer.append(None)\n                else:\n           
         conv3x3s = []\n                    for k in range(i - j):\n                        if k == i - j - 1:\n                            num_outchannels_conv3x3 = num_inchannels[i]\n                            conv3x3s.append(nn.Sequential(\n                                nn.Conv2d(num_inchannels[j],\n                                          num_outchannels_conv3x3,\n                                          3, 2, 1, bias=False),\n                                nn.BatchNorm2d(num_outchannels_conv3x3,\n                                               momentum=BN_MOMENTUM)))\n                        else:\n                            num_outchannels_conv3x3 = num_inchannels[j]\n                            conv3x3s.append(nn.Sequential(\n                                nn.Conv2d(num_inchannels[j],\n                                          num_outchannels_conv3x3,\n                                          3, 2, 1, bias=False),\n                                nn.BatchNorm2d(num_outchannels_conv3x3,\n                                               momentum=BN_MOMENTUM),\n                                nn.ReLU(inplace=False)))\n                    fuse_layer.append(nn.Sequential(*conv3x3s))\n            fuse_layers.append(nn.ModuleList(fuse_layer))\n\n        return nn.ModuleList(fuse_layers)\n\n    def get_num_inchannels(self):\n        return self.num_inchannels\n\n    def forward(self, x):\n        if self.num_branches == 1:\n            return [self.branches[0](x[0])]\n\n        for i in range(self.num_branches):\n            x[i] = self.branches[i](x[i])\n\n        x_fuse = []\n        for i in range(len(self.fuse_layers)):\n            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])\n            for j in range(1, self.num_branches):\n                if i == j:\n                    y = y + x[j]\n                elif j > i:\n                    width_output = x[i].shape[-1]\n                    height_output = x[i].shape[-2]\n                    y = y + 
F.interpolate(\n                        self.fuse_layers[i][j](x[j]),\n                        size=[height_output, width_output],\n                        mode='bilinear', align_corners=True)\n                else:\n                    y = y + self.fuse_layers[i][j](x[j])\n            x_fuse.append(self.relu(y))\n\n        return x_fuse\n\n\nblocks_dict = {\n    'BASIC': BasicBlock,\n    'BOTTLENECK': Bottleneck\n}\n\n## GX: the HighResolutionNet has 4 stages. \n## GX: each stage has one module which is HighResolutionModule.\n## GX: HighResolutionModule has 1,2,3,4 branches.\n## GX: each stage has a transitional layers in between.\nclass HighResolutionNet(nn.Module):\n\n    def __init__(self, config, **kwargs):\n        super(HighResolutionNet, self).__init__()\n\n        # noise conv\n        # self.im_conv = nn.Conv2d(3, 10, kernel_size=3, stride=1, padding=1, bias=False)\n        # self.bayar_conv = nn.Conv2d(3, 3, kernel_size=5, stride=1, padding=2, bias=False)\n        # self.constraints = BayarConstraint()\n\n        # stem net\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n\n        # # frequency branch\n        # self.conv1fre = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        # self.bn1fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        # self.conv2fre = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        # self.bn2fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        # self.laplacian = LaPlacianMs(in_c=64,gauss_ker_size=3,scale=[2,4,8])\n\n        # concat\n        self.concat_depth = CatDepth()\n        self.conv_1x1_merge = nn.Sequential(nn.Conv2d(128, 64,\n                                                  
kernel_size=1, stride=1,\n                                                  bias=False,groups=2),\n                                        nn.BatchNorm2d(64),\n                                        nn.ReLU(inplace=True),\n                                        nn.Dropout(p=0.2)\n                                       )\n        self.conv_1x1_merge.apply(weights_init('kaiming'))\n\n        self.stage1_cfg = config['STAGE1']\n        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]\n        block = blocks_dict[self.stage1_cfg['BLOCK']]\n        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]\n        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)\n        stage1_out_channel = block.expansion * num_channels\n\n        self.stage2_cfg = config['STAGE2']\n        num_channels = self.stage2_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage2_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition1 = self._make_transition_layer(\n            [stage1_out_channel], num_channels)\n        self.stage2, pre_stage_channels = self._make_stage(\n            self.stage2_cfg, num_channels)\n\n        self.stage3_cfg = config['STAGE3']\n        num_channels = self.stage3_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage3_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition2 = self._make_transition_layer(\n            pre_stage_channels, num_channels)\n        self.stage3, pre_stage_channels = self._make_stage(\n            self.stage3_cfg, num_channels)\n\n        self.stage4_cfg = config['STAGE4']\n        num_channels = self.stage4_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage4_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition3 = 
self._make_transition_layer(\n            pre_stage_channels, num_channels)\n        self.stage4, pre_stage_channels = self._make_stage(\n            self.stage4_cfg, num_channels, multi_scale_output=True)\n\n        # last_inp_channels = np.int(np.sum(pre_stage_channels))\n\n        # Classification Head\n        self.incre_modules, self.downsamp_modules, \\\n            self.final_layer = self._make_head(pre_stage_channels)\n\n        self.classifier = nn.Linear(2048, 1000)\n\n    def _make_head(self, pre_stage_channels):\n        head_block = Bottleneck\n        head_channels = [32, 64, 128, 256]\n\n        # Increasing the #channels on each resolution \n        # from C, 2C, 4C, 8C to 128, 256, 512, 1024\n        incre_modules = []\n        for i, channels  in enumerate(pre_stage_channels):\n            incre_module = self._make_layer(head_block,\n                                            channels,\n                                            head_channels[i],\n                                            1,\n                                            stride=1)\n            incre_modules.append(incre_module)\n        incre_modules = nn.ModuleList(incre_modules)\n            \n        # downsampling modules\n        downsamp_modules = []\n        for i in range(len(pre_stage_channels)-1):\n            in_channels = head_channels[i] * head_block.expansion\n            out_channels = head_channels[i+1] * head_block.expansion\n\n            downsamp_module = nn.Sequential(\n                nn.Conv2d(in_channels=in_channels,\n                          out_channels=out_channels,\n                          kernel_size=3,\n                          stride=2,\n                          padding=1),\n                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),\n                nn.ReLU(inplace=True)\n            )\n\n            downsamp_modules.append(downsamp_module)\n        downsamp_modules = nn.ModuleList(downsamp_modules)\n\n        final_layer = 
nn.Sequential(\n            nn.Conv2d(\n                in_channels=head_channels[3] * head_block.expansion,\n                out_channels=2048,\n                kernel_size=1,\n                stride=1,\n                padding=0\n            ),\n            nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),\n            nn.ReLU(inplace=True)\n        )\n\n        return incre_modules, downsamp_modules, final_layer\n\n    ## GX: one dimension matrix converts pre to pos.\n    ## GX: if channel numbers are equal, pass it directly.\n    ## GX: if channel numbers are different, using conv 3x3.\n    ## GX: https://zhuanlan.zhihu.com/p/335333233\n    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):\n        num_branches_cur = len(num_channels_cur_layer)\n        num_branches_pre = len(num_channels_pre_layer)\n\n        transition_layers = []\n        for i in range(num_branches_cur):\n            if i < num_branches_pre:\n                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:\n                    transition_layers.append(nn.Sequential(\n                        nn.Conv2d(num_channels_pre_layer[i],\n                                  num_channels_cur_layer[i],\n                                  3,\n                                  1,\n                                  1,\n                                  bias=False),\n                        nn.BatchNorm2d(\n                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),\n                        nn.ReLU(inplace=False)))\n                else:\n                    transition_layers.append(None)\n            else:\n                conv3x3s = []\n                for j in range(i + 1 - num_branches_pre):\n                    inchannels = num_channels_pre_layer[-1]\n                    outchannels = num_channels_cur_layer[i] \\\n                        if j == i - num_branches_pre else inchannels\n                    conv3x3s.append(nn.Sequential(\n                   
     nn.Conv2d(\n                            inchannels, outchannels, 3, 2, 1, bias=False),\n                        nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),\n                        nn.ReLU(inplace=False)))\n                transition_layers.append(nn.Sequential(*conv3x3s))\n\n        return nn.ModuleList(transition_layers)\n\n    ## GX: _make_layer creates a conv + bn\n    def _make_layer(self, block, inplanes, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),\n            )\n\n        layers = []\n        layers.append(block(inplanes, planes, stride, downsample))\n        inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):\n        ## GX: num_modules are all 1 in this work.\n        ## GX: light-weight architectures: num_blocks are all 0.\n        ## GX: branch numbers are 2, 3, 4.\n        num_modules = layer_config['NUM_MODULES'] \n        num_branches = layer_config['NUM_BRANCHES']\n        num_blocks = layer_config['NUM_BLOCKS']\n        num_channels = layer_config['NUM_CHANNELS']\n        block = blocks_dict[layer_config['BLOCK']]\n        fuse_method = layer_config['FUSE_METHOD']\n\n        modules = []\n        for i in range(num_modules):\n            # multi_scale_output is only used last module\n            if not multi_scale_output and i == num_modules - 1:\n                reset_multi_scale_output = False\n            else:\n                reset_multi_scale_output = True\n            modules.append(\n                
HighResolutionModule(num_branches, block, num_blocks,\n                                     num_inchannels, num_channels, fuse_method,\n                                     reset_multi_scale_output)\n            )\n            num_inchannels = modules[-1].get_num_inchannels()\n\n        return nn.Sequential(*modules), num_inchannels\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.conv2(x)\n        x = self.bn2(x)\n        x = self.relu(x)\n        x = self.layer1(x)  \n\n        x_list = []\n        for i in range(self.stage2_cfg['NUM_BRANCHES']):\n            if self.transition1[i] is not None:\n                x_list.append(self.transition1[i](x))\n            else:\n                x_list.append(x)\n        y_list = self.stage2(x_list)\n        x_list = []\n        for i in range(self.stage3_cfg['NUM_BRANCHES']):\n            if self.transition2[i] is not None:\n                if i < self.stage2_cfg['NUM_BRANCHES']:\n                    x_list.append(self.transition2[i](y_list[i]))\n                else:\n                    x_list.append(self.transition2[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        y_list = self.stage3(x_list)\n        x_list = []\n        for i in range(self.stage4_cfg['NUM_BRANCHES']):\n            if self.transition3[i] is not None:\n                if i < self.stage3_cfg['NUM_BRANCHES']:\n                    x_list.append(self.transition3[i](y_list[i]))\n                else:\n                    x_list.append(self.transition3[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        y_list = self.stage4(x_list)\n\n        # Classification Head\n        y = self.incre_modules[0](y_list[0])\n        for i in range(len(self.downsamp_modules)):\n            y = self.incre_modules[i+1](y_list[i+1]) + \\\n                        self.downsamp_modules[i](y)\n\n        y = self.final_layer(y)\n\n        
if torch._C._get_tracing_state():\n            y = y.flatten(start_dim=2).mean(dim=2)\n        else:\n            y = F.avg_pool2d(y, kernel_size=y.size()\n                                 [2:]).view(y.size(0), -1)\n\n        y = self.classifier(y)\n\n        return y\n\n    def init_weights(self, pretrained='',):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.normal_(m.weight, std=0.001)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n        if os.path.isfile(pretrained):\n            ## GX: official pre-trained dict.\n            pretrained_dict = torch.load(pretrained)    \n            print('=> loading HRNet pretrained model {}'.format(pretrained))\n            model_dict = self.state_dict()      ## GX: the current model.\n            nopretrained_dict = {k: v for k, v in model_dict.items()}\n            pretrained_dict_used = {}\n            \n            for k, v in model_dict.items():\n                pretrained_key = k\n                if pretrained_key not in pretrained_dict.keys():\n                    if 'stage2' in pretrained_key and 'fuse_layers' not in pretrained_key:\n                        if 'branches.2' in pretrained_key:\n                            pretrained_key = pretrained_key.replace('stage2.0.', 'stage3.0.')\n                        elif 'branches.3' in pretrained_key:\n                            pretrained_key = pretrained_key.replace('stage2.0.', 'stage4.0.')\n                    elif 'stage3' in pretrained_key and 'fuse_layers' not in pretrained_key:\n                        pretrained_key = pretrained_key.replace('stage3.0.', 'stage4.0.')\n                    elif 'fre' in pretrained_key:\n                        pretrained_key = pretrained_key.replace('fre', '')\n                if pretrained_key in pretrained_dict.keys():\n                    pretrained_dict_used[k] = 
pretrained_dict[pretrained_key]\n                    nopretrained_dict.pop(k)\n            print(\"no pretrain dict length is: \", len(nopretrained_dict))  ## GX: how many parameters you need to train on your own.\n            model_dict.update(pretrained_dict_used)\n            self.load_state_dict(model_dict)\n        else:\n            print(f\"{pretrained} does NOT exist.\")\n            print(f\"Please try to load the pre-trained weights of HR-Net.\")\n            import sys;sys.exit(0)\n\ndef get_seg_model(cfg, **kwargs):\n    model = HighResolutionNet(cfg, **kwargs)\n    model.init_weights(cfg.PRETRAINED)\n    return model"
  },
  {
    "path": "applications/deepfake_detection/sequence/models/hrnet/seg_hrnet_config.py",
    "content": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed under the MIT License.\n# The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn)\n# ------------------------------------------------------------------------------\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom yacs.config import CfgNode as CN\n\n# high_resoluton_net related params for segmentation\nHRNET = CN()\nHRNET.PRETRAINED_LAYERS = ['*']\nHRNET.STEM_INPLANES = 64\nHRNET.FINAL_CONV_KERNEL = 1\nHRNET.PRETRAINED = './sequence/models/hrnet/hrnet_w18_small_model_v2.pth'\n\nHRNET.STAGE1 = CN()\nHRNET.STAGE1.NUM_MODULES = 1\nHRNET.STAGE1.NUM_BRANCHES = 1\nHRNET.STAGE1.NUM_BLOCKS = [2]\nHRNET.STAGE1.NUM_CHANNELS = [64]\nHRNET.STAGE1.BLOCK = 'BOTTLENECK'\nHRNET.STAGE1.FUSE_METHOD = 'SUM'\n\nHRNET.STAGE2 = CN()\nHRNET.STAGE2.NUM_MODULES = 1\nHRNET.STAGE2.NUM_BRANCHES = 4\nHRNET.STAGE2.NUM_BLOCKS = [2, 2, 2, 2]\nHRNET.STAGE2.NUM_CHANNELS = [18, 36, 72, 144]\nHRNET.STAGE2.BLOCK = 'BASIC'\nHRNET.STAGE2.FUSE_METHOD = 'SUM'\n\nHRNET.STAGE3 = CN()\nHRNET.STAGE3.NUM_MODULES = 1\nHRNET.STAGE3.NUM_BRANCHES = 4\nHRNET.STAGE3.NUM_BLOCKS = [2, 2, 2, 2]\nHRNET.STAGE3.NUM_CHANNELS = [18, 36, 72, 144]\nHRNET.STAGE3.BLOCK = 'BASIC'\nHRNET.STAGE3.FUSE_METHOD = 'SUM'\n\nHRNET.STAGE4 = CN()\nHRNET.STAGE4.NUM_MODULES = 1\nHRNET.STAGE4.NUM_BRANCHES = 4\nHRNET.STAGE4.NUM_BLOCKS = [2, 2, 2, 2]\nHRNET.STAGE4.NUM_CHANNELS = [18, 36, 72, 144]\nHRNET.STAGE4.BLOCK = 'BASIC'\nHRNET.STAGE4.FUSE_METHOD = 'SUM'\n\n\ndef get_cfg_defaults():\n  \"\"\"Get a yacs CfgNode object with default values for my_project.\"\"\"\n  # Return a clone so that the defaults will not be altered\n  # This is for the \"local variable\" use pattern\n  return HRNET.clone()\n\nif __name__ == \"__main__\":\n  print(\"Hello World!\")\n"
  },
  {
    "path": "applications/deepfake_detection/sequence/models/run_model.sh",
    "content": "source ~/.bashrc\nconda activate HiFi_Net_deepfake\nCUDA_NUM=2\nCUDA_VISIBLE_DEVICES=$CUDA_NUM python HiFiNet_deepfake.py"
  },
  {
    "path": "applications/deepfake_detection/sequence/rnn_stratified_dataloader.py",
    "content": "# coding: utf-8\n# author: Hierarchical Fine-Grained Image Forgery Detection and Localization, CVPR2023\n# based on the sample strategy proposed in Two-branch Recurrent Network for Isolating Deepfakes in Videos, ECCV2020\nimport torch\nimport torchvision\nimport h5py\nimport os\nimport glob\nimport numpy as np\nimport json\nimport numpy as np\n\nfrom torch.utils import data\n\n# Image transformation\ndef get_image_transformation(use_laplacian=False, normalize=True):\n    transforms = []\n    if normalize:\n        transforms.extend(\n                        [torchvision.transforms.ToPILImage(), # Next line takes PIL images as input (ToPILImage() preserves the values in the input array or tensor)\n                         torchvision.transforms.ToTensor(), # To bring the pixel values in the range [0,1]\n                         torchvision.transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))]\n        )\n\n        return torchvision.transforms.Compose(transforms)\n    else:\n        transforms.extend(\n                        [torchvision.transforms.ToPILImage(), # Next line takes PIL images as input (ToPILImage() preserves the values in the input array or tensor)\n                         torchvision.transforms.ToTensor()] # To bring the pixel values in the range [0,1]\n        )\n        return torchvision.transforms.Compose(transforms)\n\n# Main dataloader \ndef get_dataloader(img_path,train_dataset_names,ctype,manipulations_dict,window_size=10,hop=1,use_laplacian=False,normalize=True,strat_sampling=False,balanced_minibatch=False,mode='train',bs=32,workers=4):\n    \"\"\"\n    This is a dataloader for Face Forensics++ dataset stored in HDF5 file format.\n    \n    The structure of the files should be as shown below:\n    filename.h5 -> keys (video names. Ex, 000_003 for manipulated and 000 for original) -> each video will further have 'n' number of\n    frames. 
f[key][i] to acces 'ith' frame of 'key' video.\n    Example of filename: FF++_Deepfakes_c40.h5, FF++_Face2Face_c23.h5, FF++_original_c0.h5, etc.\n    \n    Parameters\n    ----------\n    img_path : str\n        The location of h5 files on hard drive.\n    train_dataset_names : list\n        The datasets that are to be loaded.\n        \n    returns\n    -------\n    out: torch.utils.data.dataloader.DataLoader\n        A generator that can be used to get the required batches of sequential\n        samples of data.\n        \n    Examples\n    --------\n    img_path = '/research/cvlshare/cvl-guoxia11/FaceForensics++'\n    train_dataset_names = ['original', 'Deepfakes']\n    ctype = 'c40'\n    manipulations_dict = {0:'Deepfakes',255:'original'}\n    window_size = 10\n    hop = 5\n    use_laplacian = True\n    normalize = True\n    strat_sampling = True\n    mode='train'\n    bs=32\n    workers=0\n    train_generator = get_dataloader(img_path,train_dataset_names,ctype,manipulations_dict,window_size,hop,use_laplacian,normalize,strat_sampling,mode,bs,workers)\n    \"\"\"\n    transform = get_image_transformation(use_laplacian=False, normalize=normalize)\n    params = {'batch_size': bs,\n              'shuffle': (mode=='train'),\n              'num_workers': workers,\n              'drop_last' : (mode=='train')\n            }\n    if mode == 'test' or mode == 'val':\n        strat_sampling = False\n\n    datalist_dict = get_img_list(img_path, train_dataset_names, ctype, mode, window_size, hop, strat_sampling, balanced_minibatch)\n\n    datasets = { dataset_key : ForensicFaceDatasetRNN(img_list, img_path, dataset_key, ctype,\n                                                        manipulations_dict, window_size, hop=hop,\n                                                        use_laplacian=use_laplacian, \n                                                        strat_sampling=strat_sampling,\n                                                        transform=transform)\n 
                for dataset_key, img_list in datalist_dict.items()\n                }\n    joined_dataset = data.ConcatDataset([dataset for keys, dataset in datasets.items() ])\n    joined_generator = data.DataLoader(joined_dataset,**params,pin_memory=True)\n    return joined_generator, joined_dataset\n\n\n# Generate a dictionary with \"dataset\": [dataset-video_id-frame_start]\ndef get_img_list(img_path, datasets, ctype, split, window_size, hop, strat_sampling, balanced_minibatch, repeat_num=6):\n    # Get the video_ids based on the split\n    if split == 'train':\n        with open('/research/cvl-guoxia11/deepfake_AIGC/FaceForensics/dataset/splits/train.json', 'r') as f_json:\n            img_folders = json.load(f_json)\n    elif split == 'val':\n        with open('/research/cvl-guoxia11/deepfake_AIGC/FaceForensics/dataset/splits/val.json', 'r') as f_json:\n            img_folders = json.load(f_json)\n    elif split == 'test':\n        with open('/research/cvl-guoxia11/deepfake_AIGC/FaceForensics/dataset/splits/test.json', 'r') as f_json:\n            img_folders = json.load(f_json)\n\n    data_dict = {}\n    for dataset in datasets:\n        data_list = []\n        data_filename = glob.glob(f'{img_path}/*{dataset}*{ctype}*.h5')[0] # Find the correct data file in the img_path\n        f = h5py.File(data_filename, 'r') # Load the data file in f\n        tmp_img_folders = []\n        if dataset == \"original\":\n            tmp_img_folder = [x for sublist in img_folders for x in sublist]\n            if split == 'train' and strat_sampling and balanced_minibatch:\n                for i in range(4*repeat_num):\n                    tmp_img_folders.extend(tmp_img_folder) # Oversample by 4, then it has 2880 sequences.\n            else:\n                tmp_img_folders = tmp_img_folder\n        else:\n            _ = list(map(lambda x:[\"_\".join([x[0],x[1]]),\"_\".join([x[1],x[0]])], img_folders))\n            tmp_img_folder = [x for sublist in _ for x in sublist]\n    
        if split == 'train' and strat_sampling and balanced_minibatch:\n                for i in range(repeat_num):\n                    tmp_img_folders.extend(tmp_img_folder) # Oversample by 4, then it has 2880 sequences.\n            else:\n                tmp_img_folders = tmp_img_folder\n\n        for folder in tmp_img_folders:\n            if strat_sampling:\n                frame_limit = f[folder].shape[0]\n                if frame_limit > window_size*hop:\n                    ## we record: the dataset name, the video id (folder) and total number of frames (frame_limit)\n                    data_list.append(f'{dataset}-{folder}-{frame_limit}')\n            else:\n                # Get the indices of the starting frame of each chunk of frames\n                if f[folder].shape[0] > window_size*hop:\n                    frame_start_indices = np.arange(0, f[folder].shape[0]-(window_size*hop), window_size*hop)\n                for frame_index in frame_start_indices:\n                    data_list.append(f'{dataset}-{folder}-{frame_index}')\n        f.close()\n        data_dict[dataset] = data_list\n    return data_dict\n\nclass ForensicFaceDatasetRNN(data.Dataset):\n    def __init__(self, list_ids, img_path, dataset_name, ctype, manipulations_dict, window_size, hop, use_laplacian=False, strat_sampling=False, transform=[]):\n        super(ForensicFaceDatasetRNN, self).__init__()\n        self.list_ids = list_ids\n        self.transform = transform\n        self.use_laplacian = use_laplacian\n        self.strat_sampling = strat_sampling\n        self.dataset_name = dataset_name\n        self.dname_to_id = manipulations_dict\n        self.window_size = window_size\n        self.hop = hop\n        self.h5_handler = None\n        self.data_filename = self.get_dbfile_path(f'{img_path}/*{dataset_name}*{ctype}*.h5')\n        if not os.path.exists(self.data_filename):\n            raise RunTimeError('%s not found' % (self.data_filename))\n        if self.hop < 1:\n       
     raise ValueError(f'Minimum value of hop is 1. And you provided {self.hop}')\n        \n    def __len__(self):\n        return len(self.list_ids)\n    \n    def get_dbfile_path(self,path_pattern):\n        list_files = glob.glob(path_pattern)\n        n_files = len(list_files)\n        if n_files >=2:\n            raise RuntimeError(f'Found multiple files in {path_pattern}')\n        elif n_files == 0:\n            raise RuntimeError(f'Files not found in {path_pattern}')\n        else:\n            return list_files[0]\n        \n    def __getitem__(self, index):\n        if self.h5_handler is None:\n            self.h5_handler = h5py.File(self.data_filename, 'r', swmr=True)\n        file_id = self.list_ids[index].split('-')\n\n        data_folder = file_id[1]\n        if self.strat_sampling:\n            frame_limit = file_id[2]\n            ## now we random sample a frame within the video\n            frame_id = np.random.randint(0,int(frame_limit)-(self.window_size*self.hop))\n        else:\n            frame_id = file_id[2]\n\n        frames = self.h5_handler[data_folder][int(frame_id):int(frame_id)+(self.window_size*self.hop):self.hop]\n\n        ## Now handling the label\n        label = 1.0 if self.dataset_name == \"original\" else 0.0\n            \n        '''\n            ## visualization example:\n            import cv2\n            print(f\"the frames are: \", frames.shape)\n            # output_frames = self.transform(frames)\n            for _ in range(10):\n                frame = frames[_]\n                # print(f\"the frame is: \", frame.shape)\n                # print(\"output frames: \", frame.shape)\n                image_data = frame.astype(np.uint8)\n                image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)\n                # cv2.imshow('demo.png', image_data)\n                cv2.imwrite(f'demo_{_}_{self.dataset_name}.png', image_data)\n        '''\n        frames = torch.stack(list(map(self.transform,frames)))\n        
image_names = '~'.join([f\"{data_folder}/{int(frame_id) + i * self.hop}\" for i in range(self.window_size)])\n\n        return frames, label, image_names"
  },
  {
    "path": "applications/deepfake_detection/sequence/runjobs_utils.py",
    "content": "# coding: utf-8\n# author: Hierarchical Fine-Grained Image Forgery Detection and Localization\nimport datetime\nimport logging\nimport sys\nimport torch\nimport os\nimport datetime\n\ndef init_logger(name):\n    logger = logging.getLogger(name)\n    h = logging.StreamHandler(sys.stdout)\n    h.flush = sys.stdout.flush\n    logger.addHandler(h)\n    return logger\n\nlogger = init_logger(__name__)\nlogger.setLevel(logging.INFO)\n\ndef torch_load_model(model, optimizer, load_model_path,strict=True):\n    loaded_file = torch.load(load_model_path)\n    model.load_state_dict(loaded_file['model_state_dict'], strict=strict)\n    # model.load_state_dict(loaded_file['model_state_dict'], strict=False)\n    iteration = loaded_file['iter']\n    scheduler = loaded_file['scheduler']\n    epoch = loaded_file['epoch']\n    val_loss = 1.0\n    if 'val_loss' in loaded_file:\n        val_loss = loaded_file['val_loss']\n    # optimizer.load_state_dict(loaded_file['optimizer_state_dict'])    \n    return iteration, epoch, scheduler, val_loss\n\nclass DataConfig(object):\n    def __init__(self, model_path, model_name):\n        self.model_path = model_path\n        self.model_name = model_name\n\nclass Saver(object):\n    def __init__(self, model, optimizer, scheduler, data_config,\n                 starting_time, hours_limit=23, mins_limit=0):\n        self.model = model\n        self.optimizer = optimizer\n        self.scheduler = scheduler\n        self.best_val_loss = sys.maxsize\n        self.data_config = data_config\n        \n        self.hours_limit = hours_limit\n        self.mins_limit = mins_limit\n        self.starting_time = starting_time\n\n    def save_model(self,epoch,ib,val_loss,before_train,best_only=False,force_saving=False):\n        # if (val_loss  <= self.best_val_loss and not(before_train)) or force_saving:\n        if val_loss <= self.best_val_loss or force_saving:\n            ## preserving best_loss\n            if val_loss  <= 
self.best_val_loss:\n                self.best_val_loss = val_loss\n                \n            if best_only:\n                saving_list = [os.path.join(self.data_config.model_path,'best_model.pth')]\n\n            if force_saving:\n                saving_list = [os.path.join(self.data_config.model_path,'current_model.pth')]\n            print(\"===================================\")\n            print(f\"saving model list is: \", saving_list)\n            print(\"===================================\")\n            for ss in saving_list:\n                torch.save({'epoch': epoch,\n                            'model_state_dict': self.model.state_dict(),\n                            'optimizer_state_dict':\n                            self.optimizer.state_dict() if self.optimizer is not None else None,\n                            'iter' : ib,\n                            'scheduler' : self.scheduler,\n                            'val_loss' : val_loss,\n                            },\n                           ss\n                )\n\n    def check_time(self):\n        this_time = datetime.datetime.now()\n        days, hours, mins = self.days_hours_minutes(\n            this_time - self.starting_time)\n        return days, hours, mins\n\n    def days_hours_minutes(self, td):\n        return td.days, td.seconds//3600, (td.seconds//60) % 60"
  },
  {
    "path": "applications/deepfake_detection/sequence/torch_utils.py",
    "content": "# coding: utf-8\n# author: Hierarchical Fine-Grained Image Forgery Detection and Localization\nimport torch\nimport torch.nn as nn\nfrom tqdm import tqdm\nfrom sklearn import metrics\nimport numpy as np\nfrom runjobs_utils import init_logger\nimport logging\nimport torch.nn.functional as F\nimport os\nfrom collections import OrderedDict\nimport csv\n\nlogger = init_logger(__name__)\nlogger.setLevel(logging.INFO)\n\nclass ROC(object):\n    def __init__(self):\n        self.fpr = None\n        self.tpr = None\n        self.auc = None\n        self.scores = None\n        self.ap_0 = None\n        self.ap_1 = None\n        self.weighted_ap = None\n        \n        self.predictions = []\n        self.gt = []\n        self.best_acc = None\n\n    def get_trunc_auc(self,fpr_value):\n        abs_fpr = np.absolute(self.fpr - fpr_value)\n        idx_min = np.argmin(abs_fpr)\n        area_curve = sum(self.tpr[idx_min])\n        tot_area = sum(np.ones_like(self.tpr)[idx_min])\n        if tot_area == 0:\n            raise ZeroDivisionError('when computing truncated ROC aread')\n        t_auc = area_curve/tot_area\n        return t_auc\n\n    def get_tpr_at_fpr(self,fpr_value):\n        abs_fpr = np.absolute(self.fpr - fpr_value)\n        idx_min = np.argmin(abs_fpr)\n        fpr_value_target = self.fpr[idx_min]\n        idx = np.max(np.where(self.fpr == fpr_value_target))\n        return self.tpr[idx], self.scores[idx]\n        \n    def eval(self):\n        self.fpr, self.tpr, self.scores = metrics.roc_curve(self.gt,self.predictions,drop_intermediate=True)\n        self.auc = metrics.auc(self.fpr,self.tpr)\n\n    def compute_best_accuracy(self,n_samples=200):\n        '''find the best threshold for the accuracy.'''\n        acc_thrs = []\n        min_thr = min(self.predictions)\n        max_thr = max(self.predictions)\n        all_thrs = np.linspace(min_thr,max_thr,n_samples).tolist()\n        for t in all_thrs:\n            acc = 
self.compute_acc(self.predictions,self.gt,t)\n            acc_thrs.append((t,acc))\n        acc_thrs_arr = np.array(acc_thrs)\n        idx_max = acc_thrs_arr[:,1].argmax()\n        best_thr = acc_thrs_arr[idx_max,0]\n        self.best_acc = acc_thrs_arr[idx_max,1]\n        return best_thr, self.best_acc\n\n    def compute_acc(self,list_scores,list_labels,thr):\n        labels = np.array(list_labels)\n        scores_th = (np.array(list_scores) >= thr).astype(np.int32)\n        acc = (scores_th==labels).sum()/labels.size\n        return acc\n    \n    def get_precision(self,criterion,thr):\n        '''compute the best precision'''\n        pred_labels = []\n        for d in self.predictions:\n            if (d < thr):\n                pred_labels.append(0)\n            elif (d >= thr):\n                pred_labels.append(1)\n        self.ap_0 = metrics.precision_score(self.gt, pred_labels, average='binary', pos_label=0)\n        self.ap_1 = metrics.precision_score(self.gt, pred_labels, average='binary', pos_label=1)\n        self.weighted_ap = metrics.precision_score(self.gt, pred_labels, average='weighted')\n\nclass Metrics(object):\n    def __init__(self):\n        self.tp = 0\n        self.tot_samples = 0\n        self.loss = 0.0\n        self.loss_samples = 0\n        self.roc = ROC()\n        \n        self.best_valid_acc = 0.0\n        self.best_valid_thr = 0.0\n\n        self.tuned_acc_thrs = (0,0)\n        \n    def update(self,tp,loss_value,samples):\n        self.tp+=tp\n        self.tot_samples+=samples\n        self.loss+=loss_value\n        self.loss_samples+=1\n\n    def get_avg_loss(self):\n        if self.loss_samples == 0:\n            raise ZeroDivisionError('not enough sample to avg loss')\n        return self.loss/self.loss_samples\n\ndef count_matching_samples(preds,true_labels,criterion,use_magic_loss=True):\n    acc = 0\n    if use_magic_loss:\n        for l,d in zip(true_labels,preds):\n            if (l == criterion.class_label and d < 
criterion.R) \\\n            or (l != criterion.class_label and d >= criterion.R):\n                acc += 1\n    else:\n        matching_idx = (preds.argmax(dim=1)==true_labels)\n        acc = matching_idx.sum().item()\n    return acc\n\ndef eval_model(model,dataset_name,valid_joined_generator,criterion,\n               device,desc='valid',val_metrics=None,\n               debug_mode=False):\n    model.eval()\n    print(f\"with the eval model and the debug mode {debug_mode}.\")\n    with torch.no_grad():\n        metrics = Metrics()\n        for jb, val_batch in tqdm(enumerate(valid_joined_generator,1),\n                                  total=len(valid_joined_generator),\n                                  desc=desc):\n            if jb % 8 != 0 and debug_mode:\n                continue\n            ## Getting Input\n            val_img_batch_mmodal, val_true_labels, image_names = val_batch\n            n_samples = val_img_batch_mmodal.shape[0]\n            val_img_batch_mmodal = val_img_batch_mmodal.float().to(device)      \n            val_true_labels = val_true_labels.long().to(device)\n            ## Inference\n            val_preds = model(val_img_batch_mmodal)\n\n            ## Computing loss\n            val_loss = criterion(val_preds, val_true_labels)\n            log_probs = F.softmax(val_preds, dim=-1)\n            res_probs = torch.argmax(log_probs, dim=-1)\n            fixed_labels = 1 - val_true_labels\n                    \n            ## acc/matching_samples. 
\n            matching_num = count_matching_samples(val_preds,val_true_labels,criterion,use_magic_loss=False)\n            # metrics.roc.predictions.extend(res_probs.tolist())\n            metrics.roc.predictions.extend(log_probs[:,0].tolist())\n            ## Inverting the labels\n            metrics.roc.gt.extend(fixed_labels[:].tolist())\n            metrics.update(matching_num,val_loss.item(),n_samples)\n            \n    ## Getting the Results\n    metrics.roc.eval()\n    print(\"the auc is: %.5f\"%metrics.roc.auc)\n    best_acc = best_thr = None\n    best_thr, best_acc = metrics.roc.compute_best_accuracy()\n    metrics.best_valid_acc = best_acc\n    metrics.best_valid_thr = best_thr\n    print(\"the accuracy is: %.5f: \"%best_acc)\n    print(\"the threshold is: %.5f: \"%best_thr)\n    fpr_values = [0.1,0.01]    \n    for fpr_value in fpr_values:\n        tpr_fpr, score_for_tpr_fpr = metrics.roc.get_tpr_at_fpr(fpr_value)\n        print('tpr_fpr_%.1f: '%(fpr_value*100.0), \"%.5f\"%tpr_fpr)\n    ## Setting the model back to train mode\n    model.train()\n    return metrics\n\ndef display_eval_tb(writer,metrics,tot_iter,desc='test',old_metrics=False):\n    avg_loss = metrics.get_avg_loss()\n    acc = metrics.roc.best_acc\n    auc = metrics.roc.auc\n    writer.add_scalar('%s/loss'%desc, avg_loss, tot_iter)\n    writer.add_scalar('%s/acc'%desc, acc, tot_iter)                      \n    writer.add_scalar('%s/auc'%desc, auc, tot_iter)\n    fpr_values = [0.1,0.01]    \n    for fpr_value in fpr_values:\n        tpr_fpr, score_for_tpr_fpr = metrics.roc.get_tpr_at_fpr(fpr_value)\n        writer.add_scalar('%s/tpr_fpr_%.0f'%(desc,(fpr_value*100.0)), tpr_fpr, tot_iter)\n\ndef train_logging(string, writer, logger, epoch, saver, tot_iter, loss, accu, lr_scheduler):\n    _, hours, mins = saver.check_time()\n    logger.info(\"[Epoch %d] | h:%d m:%d | iteration: %d, loss: %f, accu: %f\", epoch, hours, mins, tot_iter,\n                loss, accu)\n    \n    
writer.add_scalar(string, loss, tot_iter )\n    for count, gp in enumerate(lr_scheduler.optimizer.param_groups,1):\n        writer.add_scalar('progress/lr_%d'%count, gp['lr'], tot_iter)\n    writer.add_scalar('progress/epoch', epoch, tot_iter)\n    writer.add_scalar('progress/curr_patience',lr_scheduler.num_bad_epochs,tot_iter)\n    writer.add_scalar('progress/patience',lr_scheduler.patience,tot_iter)\n\nclass lrSched_monitor(object):\n    \"\"\"\n    This class is used to monitor the learning rate scheduler's behavior\n    during training. If the learning rate decreases then this class re-initializes\n    the last best state of the model and starts training from that point of time.\n    \n    Parameters\n    ----------\n    model : torch model\n    scheduler : learning rate scheduler object from training\n    data_config : this object holds model_path and model_name, used to load the last best model.\n    \"\"\"\n    def __init__(self, model, scheduler, data_config):\n        self.model = model\n        self.scheduler = scheduler\n        self.model_name = data_config.model_name\n        self.model_path = data_config.model_path\n        self._last_lr = [0]*len(scheduler.optimizer.param_groups)\n        self.prev_lr_mean = self.get_lr_mean()\n    \n    ## Get the current mean learning rate from the optimizer\n    def get_lr_mean(self):\n        lr_mean = 0\n        for i, grp in enumerate(self.scheduler.optimizer.param_groups):\n            if 'lr' in grp.keys():\n                lr_mean += grp['lr']\n                self._last_lr[i] = grp['lr']\n        return lr_mean/(i+1)       \n        \n    ## This is the function that is to be called right after lr_scheduler.step(val_loss)    \n    def monitor(self):\n        if self.scheduler.num_bad_epochs == self.scheduler.patience:\n            self.prev_lr_mean = self.get_lr_mean()\n        elif self.get_lr_mean() < self.prev_lr_mean:\n            self.load_best_model()\n            self.prev_lr_mean = 
self.get_lr_mean()\n    \n    ## This function loads the last best model once the learning rate decreases\n    def load_best_model(self):\n        device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n        if torch.cuda.device_count() > 1:\n            ckpt = torch.load(os.path.join(self.model_path,'best_model.pth'))\n            self.model.load_state_dict(ckpt['model_state_dict'], strict=True)\n            self.scheduler.optimizer.load_state_dict(ckpt['optimizer_state_dict'])\n        else:\n            print(f'Loading the best model from {self.model_path}')\n            if device.type == 'cpu':\n                ckpt = torch.load(os.path.join(self.model_path,'best_model.pth'), map_location='cpu')\n            else:\n                ckpt = torch.load(os.path.join(self.model_path,'best_model.pth'))\n            ## Model State Dict\n            state_dict = ckpt['model_state_dict']\n            ## Since the model files are saved on dataparallel we use the below hack to load the weights on a model in cpu or a model on single gpu.\n            keys = state_dict.keys()\n            values = state_dict.values()\n            new_keys = []\n            for key in keys:\n                new_key = key.replace('module.','')    # remove the 'module.'\n                new_keys.append(new_key)\n\n            new_state_dict = OrderedDict(list(zip(new_keys, values))) # create a new OrderedDict with (key, value) pairs\n            self.model.load_state_dict(new_state_dict, strict=True)\n            \n            ## Optimizer State Dict\n            optim_state_dict = ckpt['optimizer_state_dict']\n            # Since the model files are saved on dataparallel we use the below hack to load the optimizer state in cpu or a model on single gpu.\n            keys = optim_state_dict.keys()\n            values = optim_state_dict.values()\n            new_keys = []\n            for key in keys:\n                new_key = key.replace('module.','')    # remove the 
'module.'\n                new_keys.append(new_key)\n\n            new_optim_state_dict = OrderedDict(list(zip(new_keys, values))) # create a new OrderedDict with (key, value) pairs\n            self.scheduler.optimizer.load_state_dict(new_optim_state_dict)\n        \n        ## Reduce the learning rate\n        for i, grp in enumerate(self.scheduler.optimizer.param_groups):\n            grp['lr'] = self._last_lr[i]\n"
  },
  {
    "path": "applications/deepfake_detection/test.py",
    "content": "# coding: utf-8\n# author: Hierarchical Fine-Grained Image Forgery Detection and Localization\nimport os\nimport numpy as np\nimport subprocess\nimport logging\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport argparse\nimport datetime\n\nfrom tensorboardX import SummaryWriter\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nsource_path = os.path.join('./sequence')\nsys.path.append(source_path)\nfrom rnn_stratified_dataloader import get_dataloader\nfrom models.HiFiNet_deepfake import HiFiNet_deepfake\nfrom torch_utils import eval_model,display_eval_tb,train_logging,lrSched_monitor\nfrom runjobs_utils import init_logger,Saver,DataConfig,torch_load_model\n\nlogger = init_logger(__name__)\nlogger.setLevel(logging.INFO)\n\nstarting_time = datetime.datetime.now()\n\n## Deterministic training\n_seed_id = 100\ntorch.backends.cudnn.deterministic = True\ntorch.manual_seed(_seed_id)\n\ndatasets = ['original', 'Deepfakes', 'FaceSwap', 'NeuralTextures', 'Face2Face']\n# datasets = ['original', 'Deepfakes']\nmanipulations_names = [n for c, n in enumerate(datasets) if n != 'original']\nmanipulations_dict = {n : c  for c, n in enumerate(manipulations_names) }\nmanipulations_dict['original'] = 255\n\nfor key, value in manipulations_dict.items():\n\tprint(key, value)\nctype = 'c40'\n\n# Create the parser\nparser = argparse.ArgumentParser(description='Process some integers.')\nparser.add_argument('--batch_size', type=int, default=4, help='input batch size for training (default: 32)')\nparser.add_argument('--window_size', type=int, default=5, help='size of the sliding window (default: 5)')\nparser.add_argument('--dataset_name', type=str, default=\"FF++\", help='size of the sliding window (default: 5)')\nparser.add_argument('--gpus', type=int, default=4, help='input batch size for training (default: 32)')\nparser.add_argument('--feat_dim', type=int, default=270, help='input dim to rnn. 
(default: 32)')\nparser.add_argument('--valid_epoch', type=int, default=2, help='val epoch')\nparser.add_argument('--display_step', type=int, default=50, help='display the loss value.')\nparser.add_argument('--learning_rate', type=float, default=1e-3, help='the used learning rate')\n\n# Parse the arguments\nargs = parser.parse_args()\n## Hyper-params #######################\nhparams = {\n            'epochs': 50, 'batch_size': args.batch_size, \n            'basic_lr': args.learning_rate, 'fine_tune': True, 'use_laplacian': True, \n            'step_factor': 0.1, 'patience': 20, 'weight_decay': 1e-06, 'lr_gamma': 2.0, 'use_magic_loss': True, \n            'feat_dim': args.feat_dim, 'drop_rate': 0.2, \n            'skip_valid': False, 'rnn_type': 'LSTM', 'rnn_hidden_size': 256, \n            'num_rnn_layers': 1, 'rnn_drop_rate': 0.2, \n            'bidir': False, 'merge_mode': 'concat', 'perc_margin_1': 0.95, 'perc_margin_2': 0.95, 'soft_boundary': False, \n            'dist_p': 2, 'radius_param': 0.84, 'strat_sampling': True, 'normalize': True, 'window_size': args.window_size, 'hop': 1, \n            'valid_epoch': args.valid_epoch, 'display_step': args.display_step, 'use_sched_monitor': True\n            }\nbatch_size = hparams['batch_size']\nbasic_lr = hparams['basic_lr']\nfine_tune = hparams['fine_tune']\nuse_laplacian = hparams['use_laplacian']\nstep_factor = hparams['step_factor']\npatience = hparams['patience']\nweight_decay = hparams['weight_decay']\nlr_gamma = hparams['lr_gamma']\nuse_magic_loss = hparams['use_magic_loss']\nfeat_dim = hparams['feat_dim']\ndrop_rate = hparams['drop_rate']\nrnn_type = hparams['rnn_type']\nrnn_hidden_size = hparams['rnn_hidden_size']\nnum_rnn_layers = hparams['num_rnn_layers']\nrnn_drop_rate = hparams['rnn_drop_rate']\nbidir = hparams['bidir']\nmerge_mode = hparams['merge_mode']\nperc_margin_1 = hparams['perc_margin_1']\nperc_margin_2 = hparams['perc_margin_2']\ndist_p = hparams['dist_p']\nradius_param = 
hparams['radius_param']\nstrat_sampling = hparams['strat_sampling']\nnormalize = hparams['normalize']\nwindow_size = hparams['window_size']\nhop = hparams['hop']\nsoft_boundary = hparams['soft_boundary']\nuse_sched_monitor = hparams['use_sched_monitor']\n########################################\nworkers_per_gpu = 6\ndataset_name = f\"{args.dataset_name}\"\n\nexp_name = f\"exp_FF_c40_bs_{batch_size}_lr_{basic_lr}_ws_{window_size}\"\nmodel_name = exp_name\nmodel_path = os.path.join(f'./{dataset_name}', model_name)\nprint(f'Window_size: {args.window_size}; Dataset: {dataset_name}; Batch_Size: {batch_size}; LR: {basic_lr}.')\nprint(f\"the model path is: \", model_path)\n\n## Data Generation\nimg_path = \"/user/guoxia11/cvlshare/cvl-guoxia11/FaceForensics_HiFiNet\"\nbalanced_minibatch_opt = True\n\nif dataset_name == 'FF++':\n    train_generator, train_dataset = get_dataloader(\n                                                img_path, datasets, ctype, manipulations_dict, window_size, hop, \n                                                use_laplacian, normalize, strat_sampling, balanced_minibatch_opt, \n                                                'train', batch_size, workers=workers_per_gpu*args.gpus\n                                                )\n    test_generator, test_dataset = get_dataloader(\n                                                img_path, datasets, ctype, manipulations_dict, window_size, hop, \n                                                use_laplacian, normalize, strat_sampling, False, \n                                                'test', batch_size, workers=workers_per_gpu*args.gpus\n                                                )\n    del train_dataset\n    del test_dataset\nelif dataset_name == \"CelebDF\":        \n    pass    ## TODO: will be released in the near future.\nelif dataset_name == 'DFW':\n    pass    ## TODO: will be released in the near future.\n\n## Model definition\ndevice = torch.device(\"cuda:0\" if 
torch.cuda.is_available() else \"cpu\")\nmodel = HiFiNet_deepfake(use_laplacian=True, drop_rate=drop_rate, use_magic_loss=False,\n                        pretrained=True, rnn_drop_rate=rnn_drop_rate,\n                        feat_dim=feat_dim, rnn_hidden_size=rnn_hidden_size, \n                        num_rnn_layers=num_rnn_layers,\n                        bidir=bidir)\nmodel = model.to(device)\nmodel = torch.nn.DataParallel(model).cuda()\n\n## Fine-tuning functions\nparams_to_optimize = model.parameters()\noptimizer = torch.optim.Adam(params_to_optimize, lr=basic_lr, weight_decay=weight_decay)\n\nlr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=step_factor, min_lr=1e-06, patience=patience, verbose=True)\ncriterion = nn.CrossEntropyLoss()\n\n## Re-loading the model in case\nepoch_init=epoch=ib=ib_off=before_train=0\n# load_model_path = os.path.join(model_path,'best_model.pth')   Not as good as the current_model.pth\nload_model_path = os.path.join(model_path,'current_model.pth')\n\nval_loss = np.inf\nif os.path.exists(load_model_path):\n    logger.info(f'Loading weights, optimizer and scheduler from {load_model_path}...')\n    ib_off, epoch_init, scheduler, val_loss = torch_load_model(model, optimizer, load_model_path)\n\n## Saver object and data config\ndata_config = DataConfig(model_path, model_name)\nsched_monitor = lrSched_monitor(model, lr_scheduler, data_config)\n\n## Start testing\nmetrics = eval_model(model,dataset_name,test_generator,criterion,device,desc='valid',val_metrics=None,debug_mode=False)"
  },
  {
    "path": "applications/deepfake_detection/test.sh",
    "content": "source ~/.bashrc\nconda activate HiFi_Net_deepfake\nCUDA_NUM=\"0,1,3,4,5,6,7\"\nCUDA_VISIBLE_DEVICES=$CUDA_NUM python test.py \\\n                                --dataset_name FF++ \\\n                                --batch_size 32 \\\n                                --window_size 10 \\\n                                --gpus 7 \\\n                                --valid_epoch 1 \\\n                                --feat_dim 1000 \\\n                                --learning_rate 1e-4 \\\n                                --display_step 150\n"
  },
  {
    "path": "applications/deepfake_detection/train.py",
    "content": "# coding: utf-8\n# author: Hierarchical Fine-Grained Image Forgery Detection and Localization\nimport os\nimport numpy as np\nimport subprocess\nimport logging\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport argparse\nimport datetime\n\nfrom tensorboardX import SummaryWriter\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nsource_path = os.path.join('./sequence')\nsys.path.append(source_path)\nfrom rnn_stratified_dataloader import get_dataloader\nfrom models.HiFiNet_deepfake import HiFiNet_deepfake\nfrom torch_utils import eval_model,display_eval_tb,train_logging,lrSched_monitor\nfrom runjobs_utils import init_logger,Saver,DataConfig,torch_load_model\n\nlogger = init_logger(__name__)\nlogger.setLevel(logging.INFO)\n\nstarting_time = datetime.datetime.now()\n\n## Deterministic training\n_seed_id = 100\ntorch.backends.cudnn.deterministic = True\ntorch.manual_seed(_seed_id)\n\ndatasets = ['original', 'Deepfakes', 'FaceSwap', 'NeuralTextures', 'Face2Face']\n# datasets = ['original', 'Deepfakes']\nmanipulations_names = [n for c, n in enumerate(datasets) if n != 'original']\nmanipulations_dict = {n : c  for c, n in enumerate(manipulations_names) }\nmanipulations_dict['original'] = 255\n\nfor key, value in manipulations_dict.items():\n\tprint(key, value)\nctype = 'c40'\n\n# Create the parser\nparser = argparse.ArgumentParser(description='Process some integers.')\nparser.add_argument('--batch_size', type=int, default=4, help='input batch size for training (default: 32)')\nparser.add_argument('--window_size', type=int, default=5, help='size of the sliding window (default: 5)')\nparser.add_argument('--dataset_name', type=str, default=\"FF++\", help='size of the sliding window (default: 5)')\nparser.add_argument('--gpus', type=int, default=4, help='input batch size for training (default: 32)')\nparser.add_argument('--feat_dim', type=int, default=270, help='input dim to rnn. 
(default: 32)')\nparser.add_argument('--valid_epoch', type=int, default=2, help='val epoch')\nparser.add_argument('--display_step', type=int, default=50, help='display the loss value.')\nparser.add_argument('--learning_rate', type=float, default=1e-3, help='the used learning rate')\n\n# Parse the arguments\nargs = parser.parse_args()\n## Hyper-params #######################\nhparams = {\n            'epochs': 50, 'batch_size': args.batch_size, \n            'basic_lr': args.learning_rate, 'fine_tune': True, 'use_laplacian': True, \n            'step_factor': 0.1, 'patience': 20, 'weight_decay': 1e-06, 'lr_gamma': 2.0, 'use_magic_loss': True, \n            'feat_dim': args.feat_dim, 'drop_rate': 0.2, \n            'skip_valid': False, 'rnn_type': 'LSTM', 'rnn_hidden_size': 256, \n            'num_rnn_layers': 1, 'rnn_drop_rate': 0.2, \n            'bidir': False, 'merge_mode': 'concat', 'perc_margin_1': 0.95, 'perc_margin_2': 0.95, 'soft_boundary': False, \n            'dist_p': 2, 'radius_param': 0.84, 'strat_sampling': True, 'normalize': True, 'window_size': args.window_size, 'hop': 1, \n            'valid_epoch': args.valid_epoch, 'display_step': args.display_step, 'use_sched_monitor': True\n            }\nbatch_size = hparams['batch_size']\nbasic_lr = hparams['basic_lr']\nfine_tune = hparams['fine_tune']\nuse_laplacian = hparams['use_laplacian']\nstep_factor = hparams['step_factor']\npatience = hparams['patience']\nweight_decay = hparams['weight_decay']\nlr_gamma = hparams['lr_gamma']\nuse_magic_loss = hparams['use_magic_loss']\nfeat_dim = hparams['feat_dim']\ndrop_rate = hparams['drop_rate']\nrnn_type = hparams['rnn_type']\nrnn_hidden_size = hparams['rnn_hidden_size']\nnum_rnn_layers = hparams['num_rnn_layers']\nrnn_drop_rate = hparams['rnn_drop_rate']\nbidir = hparams['bidir']\nmerge_mode = hparams['merge_mode']\nperc_margin_1 = hparams['perc_margin_1']\nperc_margin_2 = hparams['perc_margin_2']\ndist_p = hparams['dist_p']\nradius_param = 
hparams['radius_param']\nstrat_sampling = hparams['strat_sampling']\nnormalize = hparams['normalize']\nwindow_size = hparams['window_size']\nhop = hparams['hop']\nsoft_boundary = hparams['soft_boundary']\nuse_sched_monitor = hparams['use_sched_monitor']\n########################################\nworkers_per_gpu = 6\ndataset_name = f\"{args.dataset_name}\"\nexp_name = f\"exp_FF_c40_bs_{batch_size}_lr_{basic_lr}_ws_{window_size}\"\nmodel_name = exp_name\nmodel_path = os.path.join(f'./{dataset_name}', model_name)\nprint(f'Window_size: {args.window_size}; Dataset: {dataset_name}; Batch_Size: {batch_size}; LR: {basic_lr}.')\nos.makedirs('./log', exist_ok=True)\nlog_file_path = f\"log/{exp_name}.txt\"\nwith open(log_file_path, \"a+\") as log_file:\n    log_file.write(\n        f'Dataset Name: {dataset_name} \\n' \n        f'Window_size: {args.window_size}'\n    )\n\n# Create the model path if doesn't exists\nif not os.path.exists(model_path):\n    subprocess.call(f\"mkdir -p {model_path}\", shell=True)\n\n## Data Generation\nimg_path = \"/user/guoxia11/cvlshare/cvl-guoxia11/FaceForensics_HiFiNet\"\nbalanced_minibatch_opt = True\n\nif dataset_name == 'FF++':\n    train_generator, train_dataset = get_dataloader(\n                                                img_path, datasets, ctype, manipulations_dict, window_size, hop, \n                                                use_laplacian, normalize, strat_sampling, balanced_minibatch_opt, \n                                                'train', batch_size, workers=workers_per_gpu*args.gpus\n                                                )\n    test_generator, test_dataset = get_dataloader(\n                                                img_path, datasets, ctype, manipulations_dict, window_size, hop, \n                                                use_laplacian, normalize, strat_sampling, False, \n                                                'test', batch_size, workers=workers_per_gpu*args.gpus\n                    
                            )\n    # print(\"the dataset length is: \", len(train_dataset))\n    # print(\"the dataloader length is: \", len(train_generator))\n    del train_dataset\n    del test_dataset\nelif dataset_name == \"CelebDF\":        \n    pass    ## TODO: will be released in the near future. \nelif dataset_name == 'DFW':\n    pass    ## TODO: will be released in the near future. \n\n## Model definition\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nmodel = HiFiNet_deepfake(use_laplacian=True, drop_rate=drop_rate, use_magic_loss=False,\n                        pretrained=True, rnn_drop_rate=rnn_drop_rate,\n                        feat_dim=feat_dim, rnn_hidden_size=rnn_hidden_size, \n                        num_rnn_layers=num_rnn_layers,\n                        bidir=bidir)\nmodel = model.to(device)\nmodel = torch.nn.DataParallel(model).cuda()\n\n## Fine-tuning functions\nparams_to_optimize = model.parameters()\noptimizer = torch.optim.Adam(params_to_optimize, lr=basic_lr, weight_decay=weight_decay)\nlr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=step_factor, min_lr=1e-09, patience=patience, verbose=True)\ncriterion = nn.CrossEntropyLoss()\n\n## Re-loading the model in case\nepoch_init=epoch=ib=ib_off=before_train=0\nload_model_path = os.path.join(model_path,'current_model.pth')\nval_loss = np.inf\nif os.path.exists(load_model_path):\n    logger.info(f'Loading weights, optimizer and scheduler from {load_model_path}...')\n    _, _, _, _ = torch_load_model(model, optimizer, load_model_path)\n\n## Saver object and data config\ndata_config = DataConfig(model_path, model_name)\nsaver = Saver(model, optimizer, lr_scheduler, data_config, starting_time, hours_limit=23, mins_limit=0)\nsched_monitor = lrSched_monitor(model, lr_scheduler, data_config)\n\n## Writer summary for tb\ntb_folder = os.path.join(model_path, 'tb_logs',model_name)\nwriter = SummaryWriter(tb_folder)\nlog_string_config = '  '.join([k+':'+str(v) for 
k,v in hparams.items()])\nwriter.add_text('config : %s' % model_name, log_string_config, 0)\n\nif epoch_init == 0:\n    model.zero_grad()\n## Start training\ntot_iter = 0\ntotal_loss = 0\ntotal_accu = 0\nfor epoch in range(epoch_init,hparams['epochs']):\n    logger.info(f'Epoch ############: {epoch}')\n    for ib, (img_batch_mmodal, true_labels, manip_type) in enumerate(train_generator,1):\n        img_batch = img_batch_mmodal.float().to(device)\n        true_labels = true_labels.long().to(device)\n        optimizer.zero_grad()\n        pred_labels = model(img_batch)\n        loss = criterion(pred_labels, true_labels)\n        total_loss += loss.item()\n        log_probs = F.softmax(pred_labels, dim=-1)\n        res_probs = torch.argmax(log_probs, dim=-1)\n        summation = torch.sum(res_probs == true_labels)\n        accu = summation / img_batch.shape[0]\n        total_accu += accu\n        loss.backward()\n        optimizer.step()\n        tot_iter += 1\n        if tot_iter % hparams['display_step'] == 0:\n            train_logging(\n                        'loss/train_loss_iter', writer, logger, epoch, saver, \n                        tot_iter, total_loss/hparams['display_step'], \n                        total_accu/hparams['display_step'], lr_scheduler\n                        )\n            with open(log_file_path, \"a+\") as log_file:\n                log_file.write(\n                    f\"Epoch: {epoch}, Iteration: {tot_iter}, \"\n                    f\"Train Loss: {total_loss/hparams['display_step']:.4f}, \"\n                    f\"Accuracy: {total_accu/hparams['display_step']:.4f}\\n\"\n                )\n            total_loss = 0\n            total_accu = 0\n    saver.save_model(epoch,tot_iter,sys.maxsize,before_train,force_saving=True)\n\n    if (epoch % hparams['valid_epoch'] == 0) or (epoch == hparams['epochs']):\n        metrics = eval_model(model,dataset_name,test_generator,criterion,device,desc='valid',val_metrics=None,debug_mode=False)\n        
# metrics = eval_model(model,dataset_name,test_generator,criterion,device,desc='valid',val_metrics=None,debug_mode=True)\n        val_loss = metrics.get_avg_loss()\n        saver.save_model(epoch,ib+ib_off,val_loss,before_train,best_only=True)\n        # display_eval_tb(writer,metrics,epoch,desc='valid')\n        display_eval_tb(writer,metrics,epoch,desc='test')\n        lr_scheduler.step(val_loss)\n        sched_monitor.monitor()\n        for i, grp in enumerate(sched_monitor.scheduler.optimizer.param_groups):\n            if 'lr' in grp.keys():\n                print(\"the first grp learning rate is: \", grp['lr'])\n                break\n        file_path = f\"./{exp_name}.txt\"\n        os.makedirs(os.path.dirname(file_path), exist_ok=True)\n        with open(file_path, 'a') as f:\n            f.write(f\"AUC: {metrics.roc.auc}\\n\")\n            f.write(f\"Best Accuracy: {metrics.best_valid_acc} (Threshold: {metrics.best_valid_thr})\\n\")\n            for fpr_value in [0.1, 0.01]:\n                tpr_fpr, score_for_tpr_fpr = metrics.roc.get_tpr_at_fpr(fpr_value)\n                f.write(f\"TPR at FPR={fpr_value*100}%: {tpr_fpr} (Score: {score_for_tpr_fpr})\\n\")\n            f.write(f\"Average Loss: {metrics.get_avg_loss()}\\n\")\n            f.write(\"#\" * 100)"
  },
  {
    "path": "applications/deepfake_detection/train.sh",
    "content": "source ~/.bashrc\nconda activate HiFi_Net_deepfake\nCUDA_NUM=0,1,3,4,5,6\nCUDA_VISIBLE_DEVICES=$CUDA_NUM python train.py \\\n                                --dataset_name FF++ \\\n                                --batch_size 32 \\\n                                --window_size 10 \\\n                                --gpus 6 \\\n                                --valid_epoch 1 \\\n                                --feat_dim 1000 \\\n                                --learning_rate 1e-4 \\\n                                --display_step 150\n"
  },
  {
    "path": "data_dir/CASIA/CASIA1/fake.txt",
    "content": "Sp_D_CND_A_pla0005_pla0023_0281.jpg\nSp_D_CND_A_sec0056_sec0015_0282.jpg\nSp_D_CNN_A_ani0049_ani0084_0266.jpg\n"
  },
  {
    "path": "data_dir/CASIA/CASIA2/fake.txt",
    "content": "Tp_D_CND_M_N_ani00018_sec00096_00138.tif\nTp_D_CND_M_N_art00076_art00077_10289.tif\nTp_D_CND_M_N_art00077_art00076_10290.tif\n"
  },
  {
    "path": "data_dir/Coverage/fake.txt",
    "content": "10t.tif\n11t.tif\n12t.tif\n13t.tif\n14t.tif\n15t.tif\n16t.tif\n17t.tif\n18t.tif\n19t.tif\n1t.tif\n"
  },
  {
    "path": "data_dir/IMD2020/fake.txt",
    "content": "00010_fake_01.jpg\n"
  },
  {
    "path": "data_dir/NIST16/alllist.txt",
    "content": "probe/NC2016_0016.jpg mask/mani_NC2016_0940.png\nprobe/NC2016_0128.jpg mask/mani_NC2016_3942.png\nprobe/NC2016_0130.jpg mask/mani_NC2016_6409.png"
  },
  {
    "path": "data_dir/columbia/vallist.txt",
    "content": "canong3_canonxt_sub_01.tif\ncanong3_canonxt_sub_02.tif\ncanong3_canonxt_sub_03.tif\ncanong3_canonxt_sub_04.tif\ncanong3_canonxt_sub_05.tif\ncanong3_canonxt_sub_06.tif\ncanong3_canonxt_sub_07.tif\ncanong3_canonxt_sub_08.tif\ncanong3_canonxt_sub_09.tif\n"
  },
  {
    "path": "environment.yml",
    "content": "name: HiFi_Net\nchannels:\n  - conda-forge\n  - pytorch\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_gnu\n  - absl-py=1.3.0=py37h06a4308_0\n  - aiohttp=3.8.3=py37h5eee18b_0\n  - aiosignal=1.2.0=pyhd3eb1b0_0\n  - async-timeout=4.0.2=py37h06a4308_0\n  - asynctest=0.13.0=py_0\n  - attrs=22.1.0=py37h06a4308_0\n  - blas=1.0=mkl\n  - blinker=1.4=py37h06a4308_0\n  - brotlipy=0.7.0=py37h27cfd23_1003\n  - bzip2=1.0.8=h7b6447c_0\n  - c-ares=1.19.1=h5eee18b_0\n  - ca-certificates=2023.12.12=h06a4308_0\n  - cachetools=4.2.2=pyhd3eb1b0_0\n  - certifi=2022.12.7=py37h06a4308_0\n  - cffi=1.15.1=py37h5eee18b_3\n  - charset-normalizer=2.0.4=pyhd3eb1b0_0\n  - click=8.0.4=py37h06a4308_0\n  - cryptography=39.0.1=py37h9ce1e76_0\n  - cudatoolkit=11.3.1=h2bc3f7f_2\n  - cycler=0.11.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=hf484d3e_0\n  - fftw=3.3.9=h27cfd23_1\n  - freetype=2.12.1=h4a9f257_0\n  - frozenlist=1.3.3=py37h5eee18b_0\n  - giflib=5.2.1=h5eee18b_3\n  - gmp=6.2.1=h295c915_3\n  - gnutls=3.6.15=he1e5248_0\n  - google-auth=2.6.0=pyhd3eb1b0_0\n  - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0\n  - grpcio=1.42.0=py37hce63b2e_0\n  - icu=67.1=he1b5a44_0\n  - idna=3.4=py37h06a4308_0\n  - imageio=2.9.0=pyhd3eb1b0_0\n  - importlib-metadata=4.11.3=py37h06a4308_0\n  - intel-openmp=2021.4.0=h06a4308_3561\n  - joblib=1.1.0=pyhd3eb1b0_0\n  - jpeg=9e=h5eee18b_1\n  - kiwisolver=1.4.4=py37h6a678d5_0\n  - lame=3.100=h7b6447c_0\n  - lcms2=2.12=h3be6417_0\n  - ld_impl_linux-64=2.38=h1181459_1\n  - lerc=3.0=h295c915_0\n  - libblas=3.9.0=12_linux64_mkl\n  - libcblas=3.9.0=12_linux64_mkl\n  - libdeflate=1.17=h5eee18b_1\n  - libffi=3.4.4=h6a678d5_0\n  - libgcc-ng=11.2.0=h1234567_1\n  - libgfortran-ng=11.2.0=h00389a5_1\n  - libgfortran5=11.2.0=h1234567_1\n  - libgomp=11.2.0=h1234567_1\n  - libiconv=1.16=h7f8727e_2\n  - libidn2=2.3.4=h5eee18b_0\n  - libpng=1.6.39=h5eee18b_0\n  - libprotobuf=3.20.3=he621ea3_0\n  - libstdcxx-ng=11.2.0=h1234567_1\n  - 
libtasn1=4.19.0=h5eee18b_0\n  - libtiff=4.5.1=h6a678d5_0\n  - libunistring=0.9.10=h27cfd23_0\n  - libuv=1.44.2=h5eee18b_0\n  - libwebp=1.2.4=h11a3e52_1\n  - libwebp-base=1.2.4=h5eee18b_1\n  - lz4-c=1.9.4=h6a678d5_0\n  - markdown=3.4.1=py37h06a4308_0\n  - markupsafe=2.1.1=py37h7f8727e_0\n  - matplotlib=3.2.2=1\n  - matplotlib-base=3.2.2=py37h1d35a4c_1\n  - mkl=2021.4.0=h06a4308_640\n  - mkl-service=2.4.0=py37h7f8727e_0\n  - mkl_fft=1.3.1=py37hd3c417c_0\n  - mkl_random=1.2.2=py37h51133e4_0\n  - multidict=6.0.2=py37h5eee18b_0\n  - ncurses=6.4=h6a678d5_0\n  - nettle=3.7.3=hbbd107a_1\n  - numpy=1.21.5=py37h6c91a56_3\n  - numpy-base=1.21.5=py37ha15fc14_3\n  - oauthlib=3.2.1=py37h06a4308_0\n  - openh264=2.1.1=h4ff587b_0\n  - openssl=1.1.1w=h7f8727e_0\n  - pillow=9.4.0=py37h6a678d5_0\n  - pip=23.3.2=pyhd8ed1ab_0\n  - protobuf=3.20.3=py37h6a678d5_0\n  - pyasn1=0.4.8=pyhd3eb1b0_0\n  - pyasn1-modules=0.2.8=py_0\n  - pycparser=2.21=pyhd3eb1b0_0\n  - pyjwt=2.4.0=py37h06a4308_0\n  - pyopenssl=23.0.0=py37h06a4308_0\n  - pyparsing=3.0.9=py37h06a4308_0\n  - pysocks=1.7.1=py37_1\n  - python=3.7.16=h7a1cb2a_0\n  - python-dateutil=2.8.2=pyhd3eb1b0_0\n  - python_abi=3.7=2_cp37m\n  - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0\n  - pytorch-mutex=1.0=cuda\n  - pyyaml=6.0=py37h5eee18b_1\n  - readline=8.2=h5eee18b_0\n  - requests=2.28.1=py37h06a4308_0\n  - requests-oauthlib=1.3.0=py_0\n  - rsa=4.7.2=pyhd3eb1b0_1\n  - scikit-learn=1.0.2=py37hf9e9bfc_0\n  - scipy=1.7.3=py37h6c91a56_2\n  - setuptools=68.2.2=pyhd8ed1ab_0\n  - six=1.16.0=pyhd3eb1b0_1\n  - sqlite=3.41.2=h5eee18b_0\n  - tensorboard=2.10.0=py37h06a4308_0\n  - tensorboard-data-server=0.6.1=py37h52d8a92_0\n  - tensorboard-plugin-wit=1.8.1=py37h06a4308_0\n  - threadpoolctl=2.2.0=pyh0d69192_0\n  - tk=8.6.12=h1ccaba5_0\n  - torchvision=0.12.0=py37_cu113\n  - tornado=5.1.1=py37h7b6447c_0\n  - tqdm=4.64.1=py37h06a4308_0\n  - typing-extensions=4.3.0=py37h06a4308_0\n  - typing_extensions=4.3.0=py37h06a4308_0\n  - 
urllib3=1.26.14=py37h06a4308_0\n  - werkzeug=2.2.2=py37h06a4308_0\n  - wheel=0.38.4=py37h06a4308_0\n  - xz=5.4.5=h5eee18b_0\n  - yacs=0.1.6=pyhd3eb1b0_1\n  - yaml=0.2.5=h7b6447c_0\n  - yarl=1.8.1=py37h5eee18b_0\n  - zipp=3.11.0=py37h06a4308_0\n  - zlib=1.2.13=h5eee18b_0\n  - zstd=1.5.5=hc292b87_0\n  - pip:\n      - einops==0.6.1\n      - kmeans-pytorch==0.3\n      - opencv-python==4.8.1.78\nprefix: /home/aya/.conda/envs/HiFi_Net\n"
  },
  {
    "path": "models/GaussianSmoothing.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\n\nimport math\nimport numbers\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass GaussianSmoothing(nn.Module):\n    \"\"\"\n    Apply gaussian smoothing on a\n    1d, 2d or 3d tensor. Filtering is performed seperately for each channel\n    in the input using a depthwise convolution.\n    Arguments:\n    channels (int, sequence): Number of channels of the input tensors. Output will\n    have this number of channels as well.\n    kernel_size (int, sequence): Size of the gaussian kernel.\n    sigma (float, sequence): Standard deviation of the gaussian kernel.\n    dim (int, optional): The number of dimensions of the data.\n    Default value is 2 (spatial).\n    \"\"\"\n    def __init__(self, channels, kernel_size, sigma, dim=2):\n        super(GaussianSmoothing, self).__init__()\n        if isinstance(kernel_size, numbers.Number):\n            kernel_size = [kernel_size] * dim\n        if isinstance(sigma, numbers.Number):\n            sigma = [sigma] * dim\n\n        # The gaussian kernel is the product of the\n        # gaussian function of each dimension.\n        kernel = 1\n        meshgrids = torch.meshgrid(\n            [\n                torch.arange(size, dtype=torch.float32)\n                for size in kernel_size\n            ], indexing='ij'\n        )\n        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):\n            mean = (size - 1) / 2\n            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \\\n                      torch.exp(-((mgrid - mean) / std) ** 2 / 2)\n        # Make sure sum of values in gaussian kernel equals 1.\n        kernel = kernel / torch.sum(kernel)\n\n        # Reshape to depthwise 
convolutional weight\n        kernel = kernel.view(1, 1, *kernel.size())\n        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))\n        self.register_buffer('weight', kernel)\n        self.groups = channels\n\n        if dim == 1:\n            self.conv = F.conv1d\n        elif dim == 2:\n            self.conv = F.conv2d\n        elif dim == 3:\n            self.conv = F.conv3d\n        else:\n            raise RuntimeError(\n                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)\n            )\n\n    def forward(self, input):\n        \"\"\"\n        Apply gaussian filter to input.\n        Arguments:\n        input (torch.Tensor): Input to apply gaussian filter on.\n        Returns:\n        filtered (torch.Tensor): Filtered output.\n        \"\"\"\n        return self.conv(input, weight=self.weight, groups=self.groups)"
  },
  {
    "path": "models/LaPlacianMs.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom .GaussianSmoothing import GaussianSmoothing\n\nclass LaPlacianMs(nn.Module):\n    def __init__(self,in_c,gauss_ker_size=3,scale=[2],drop_rate=0.2):\n        super(LaPlacianMs, self).__init__()\n        self.scale = scale\n        self.gauss_ker_size = gauss_ker_size\n        ## apply gaussian smoothing to input feature maps with 3 planes\n        ## with kernel size K and sigma s\n        self.smoothing = nn.ModuleDict()\n        for s in self.scale:\n            self.smoothing['scale-'+str(s)] = GaussianSmoothing(in_c, self.gauss_ker_size, s)\n        self.conv_1x1 = nn.Sequential(nn.Conv2d(in_c*len(scale), in_c,\n                                                kernel_size=1, stride=1,\n                                                bias=False,groups=1),\n                                                nn.BatchNorm2d(in_c),\n                                                nn.ReLU(inplace=True),\n                                                nn.Dropout(p=drop_rate)\n        )\n        # Official init from torch repo.\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.constant_(m.bias, 0)\n\n    def down(self,x,s):\n        return F.interpolate(x,scale_factor=s,\n                             mode='bilinear',\n                             align_corners=False)\n    def up (self,x, size):\n        return 
F.interpolate(x,size=size,mode='bilinear',align_corners=False)\n\n    def forward(self, x):\n        for i, s in enumerate(self.scale):\n            sm = self.smoothing['scale-'+str(s)](x)\n            sm = self.down(sm,1/s)\n            sm = self.up(sm,(x.shape[2],x.shape[3]))\n            if i == 0:\n                diff = x - sm\n            else:\n                diff = torch.cat((diff, x - sm), dim=1)\n        return self.conv_1x1(diff)"
  },
  {
    "path": "models/NLCDetection_api.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.seg_hrnet_config import get_cfg_defaults\nimport time\n\ndef weights_init(init_type='gaussian'):\n    def init_fun(m):\n        classname = m.__class__.__name__\n        if (classname.find('Conv') == 0 or classname.find(\n                'Linear') == 0) and hasattr(m, 'weight'):\n            if init_type == 'gaussian':\n                nn.init.normal_(m.weight, 0.0, 0.02)\n            elif init_type == 'xavier':\n                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))\n            elif init_type == 'kaiming':\n                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))\n            elif init_type == 'default':\n                pass\n            else:\n                assert 0, \"Unsupported initialization: {}\".format(init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                nn.init.constant_(m.bias, 0.0)\n    return init_fun\n\nclass PartialConv(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, dilation=1, groups=1, bias=True):\n        super().__init__()\n        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,\n                                    stride, padding, dilation, groups, bias)\n        self.mask_conv  = nn.Conv2d(in_channels, out_channels, kernel_size,\n                                    stride, padding, dilation, groups, False)\n        self.input_conv.apply(weights_init('kaiming'))\n        
torch.nn.init.constant_(self.mask_conv.weight, 1.0)\n        # mask is not updated\n        for param in self.mask_conv.parameters():\n            param.requires_grad = False\n\n    def forward(self, input, mask):\n        # http://masc.cs.gmu.edu/wiki/partialconv\n        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)\n        # W^T* (M .* X) / sum(M) + b = [C(M .* X) – C(0)] / D(M) + C(0)\n\n        ## GX: masking the input outside function.\n        output = self.input_conv(input)\n        if self.input_conv.bias is not None:\n            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output)\n        else:\n            output_bias = torch.zeros_like(output)        \n\n        with torch.no_grad():\n            output_mask = self.mask_conv(mask)\n\n        no_update_holes = output_mask == 0\n\n        ## in output_mask, fills the 0-value-position with 1.0\n        ## without this step, math error occurs.\n        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)\n        output_pre = (output - output_bias) / mask_sum + output_bias\n        output = output_pre.masked_fill_(no_update_holes, 0.0)\n\n        new_mask = torch.ones_like(output)\n        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)\n        \n        return output, new_mask\n\nclass NonLocalMask(nn.Module):\n    def __init__(self, in_channels, reduce_scale):\n        super(NonLocalMask, self).__init__()\n\n        self.r = reduce_scale\n\n        # input channel number\n        self.ic = in_channels * self.r * self.r\n\n        # middle channel number\n        self.mc = self.ic\n\n        self.g = nn.Conv2d(in_channels=self.ic, out_channels=self.ic,\n                           kernel_size=1, stride=1, padding=0)\n\n        self.theta = nn.Conv2d(in_channels=self.ic, out_channels=self.mc,\n                               kernel_size=1, stride=1, padding=0)\n        self.phi = nn.Conv2d(in_channels=self.ic, out_channels=self.mc,\n                             
kernel_size=1, stride=1, padding=0)\n        self.W_s = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.gamma_s = nn.Parameter(torch.ones(1))\n        self.getmask = nn.Sequential(\n                                    nn.Conv2d(in_channels=in_channels, out_channels=16, \n                                              kernel_size=3, stride=1, padding=1),\n                                    nn.ReLU(),\n                                    nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1)\n                                    )\n\n        ## Pconv\n        self.Pconv_1 = PartialConv(3, 3, kernel_size=3, stride=2)\n        self.Pconv_2 = PartialConv(3, 3, kernel_size=3, stride=2)\n        self.Pconv_3 = PartialConv(3, 1, kernel_size=3, stride=2)\n\n    def forward(self, x, img):\n        b, c, h, w = x.shape\n\n        x1 = x.reshape(b, self.ic, h // self.r, w // self.r)\n\n        # g x\n        g_x = self.g(x1).view(b, self.ic, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        # theta\n        theta_x = self.theta(x1).view(b, self.mc, -1)\n        theta_x_s = theta_x.permute(0, 2, 1)\n\n        # phi x\n        phi_x = self.phi(x1).view(b, self.mc, -1)\n        phi_x_s = phi_x\n\n        # non-local attention\n        f_s = torch.matmul(theta_x_s, phi_x_s)\n        f_s_div = F.softmax(f_s, dim=-1)\n\n        # get y_s\n        y_s = torch.matmul(f_s_div, g_x)\n        y_s = y_s.permute(0, 2, 1).contiguous()\n        y_s = y_s.view(b, c, h, w)\n\n        # GX: (256,256,18), output mask for the deep metric loss.\n        mask_feat = x + self.gamma_s * self.W_s(y_s)\n\n        # get 1-dimensional mask_tmp\n        mask_binary = torch.sigmoid(self.getmask(mask_feat))\n        mask_tmp = mask_binary.repeat(1, 3, 1, 1)\n        mask_img = img * mask_tmp # mask_img is the overlaid image.\n\n        ## conv output\n        x, new_mask = 
self.Pconv_1(mask_img, mask_tmp)\n        x, new_mask = self.Pconv_2(x, new_mask)\n        x, _        = self.Pconv_3(x, new_mask)\n        mask_binary = mask_binary.squeeze(dim=1)\n        return x, mask_feat, mask_binary\n\nclass Flatten(nn.Module):\n    def __init__(self):\n        super(Flatten, self).__init__()\n        \n    def forward(self, x):       \n        return x.view(x.size(0), -1)\n\nclass Classifer(nn.Module):\n    def __init__(self, in_channels, output_channels):\n        super(Classifer, self).__init__()\n        self.pool = nn.Sequential(\n                                  # nn.AdaptiveAvgPool2d((1,1)),\n                                  nn.AdaptiveAvgPool2d(1),\n                                  Flatten()\n                                )\n        self.fc = nn.Linear(in_channels, output_channels, bias=True)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        feat = self.pool(x)\n        feat = self.relu(feat)\n        cls_res = self.fc(feat)\n        return cls_res\n\nclass BranchCLS(nn.Module):\n    def __init__(self, in_channels, output_channels):\n        super(BranchCLS, self).__init__()\n        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),\n                                  Flatten()\n                                )\n        self.fc = nn.Linear(18, output_channels, bias=True)\n        self.bn = nn.BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n        self.branch_cls = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=32, \n                                                  padding=1, kernel_size=3, stride=1),\n                                        nn.ReLU(inplace=True),\n                                        nn.Conv2d(in_channels=32, out_channels=18,\n                                                  padding=1, kernel_size=3, stride=1),\n                                        nn.ReLU(inplace=True), \n                                        )\n      
  self.leakyrelu = nn.LeakyReLU(0.2)\n\n    def forward(self, x):\n        feat = self.branch_cls(x)\n        x = self.pool(feat)\n        x = self.bn(x)\n        cls_res = self.fc(x)\n        cls_pro = self.leakyrelu(cls_res)\n        zero_vec = -9e15*torch.ones_like(cls_pro)\n        cls_pro  = torch.where(cls_pro > 0, cls_pro, zero_vec)\n        return cls_res, cls_pro, feat\n\nclass NLCDetection(nn.Module):\n    def __init__(self):\n        super(NLCDetection, self).__init__()\n        self.split_tensor_1 = torch.tensor([1, 3]).cuda()\n        self.split_tensor_2 = torch.tensor([1, 2, 1, 3]).cuda()\n        self.softmax_m = nn.Softmax(dim=1)\n        FENet_cfg = get_cfg_defaults()\n        feat1_num, feat2_num, feat3_num, feat4_num = FENet_cfg['STAGE4']['NUM_CHANNELS']\n\n        ## mask generation branch.\n        self.getmask = NonLocalMask(feat1_num, 4)\n\n        ## classification branch.\n        self.branch_cls_level_1 = BranchCLS(271, 14)   # 252 + 18 + 1 (pconv) = 271\n        self.branch_cls_level_2 = BranchCLS(252, 7)    # 144+72+36 = 252\n        self.branch_cls_level_3 = BranchCLS(216, 5)    # 144+72 = 216\n        self.branch_cls_level_4 = BranchCLS(144, 3)    # 144\n\n    def forward(self, feat, img):\n        s1, s2, s3, s4 = feat\n\n        pconv_feat, mask, mask_binary = self.getmask(s1, img)\n        pconv_feat = pconv_feat.clone().detach()\n\n        pconv_1 = F.interpolate(pconv_feat, size=s1.size()[2:], mode='bilinear', align_corners=True)\n\n        ## fourth branch.\n        cls_4, pro_4, _ = self.branch_cls_level_4(s4)\n        cls_prob_4      = self.softmax_m(pro_4)\n        cls_prob_40 = torch.unsqueeze(cls_prob_4[:,0],1)\n        cls_prob_41 = torch.unsqueeze(cls_prob_4[:,1],1)\n        cls_prob_42 = torch.unsqueeze(cls_prob_4[:,2],1)\n        cls_prob_mask_3 = torch.cat([cls_prob_40, cls_prob_41, cls_prob_41, cls_prob_42, cls_prob_42],axis=1)\n\n        ## third branch\n        s4F = F.interpolate(s4, size=s3.size()[2:], mode='bilinear', 
align_corners=True)\n        s3_input = torch.cat([s4F, s3], axis=1)\n        cls_3, pro_3, _ = self.branch_cls_level_3(s3_input)\n        cls_prob_3      = self.softmax_m(pro_3)\n        cls_3 = cls_3 + cls_3 * cls_prob_mask_3\n        cls_prob_30 = torch.unsqueeze(cls_prob_3[:,0],1)\n        cls_prob_31 = torch.unsqueeze(cls_prob_3[:,1],1)\n        cls_prob_32 = torch.unsqueeze(cls_prob_3[:,2],1)\n        cls_prob_33 = torch.unsqueeze(cls_prob_3[:,3],1)\n        cls_prob_34 = torch.unsqueeze(cls_prob_3[:,4],1)\n        cls_prob_mask_2 = torch.cat([cls_prob_30, cls_prob_31, cls_prob_31, \n                                     cls_prob_32, cls_prob_32,\n                                     cls_prob_33, cls_prob_34],axis=1)\n\n        ## second branch\n        s3F = F.interpolate(s3_input, size=s2.size()[2:], mode='bilinear', align_corners=True)\n        s2_input = torch.cat([s3F, s2], axis=1)\n        cls_2, pro_2, _ = self.branch_cls_level_2(s2_input) \n        cls_prob_2      = self.softmax_m(pro_2)\n        cls_2 = cls_2 + cls_2 * cls_prob_mask_2\n        cls_prob_20 = torch.unsqueeze(cls_prob_2[:,0],1)\n        cls_prob_21 = torch.unsqueeze(cls_prob_2[:,1],1)\n        cls_prob_22 = torch.unsqueeze(cls_prob_2[:,2],1)\n        cls_prob_23 = torch.unsqueeze(cls_prob_2[:,3],1)\n        cls_prob_24 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_25 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_26 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_mask_1 = torch.cat([cls_prob_20, \n                                     cls_prob_21, cls_prob_21, cls_prob_22, cls_prob_22,    # 4 diffusion\n                                     cls_prob_23, cls_prob_23, cls_prob_24, cls_prob_24,    # 4 gan\n                                     cls_prob_25, cls_prob_25,                              # faceshifter+stgan\n                                     cls_prob_26, cls_prob_26, cls_prob_26], axis=1)        # 3 editing\n\n        s2F = F.interpolate(s2_input, 
size=s1.size()[2:], mode='bilinear', align_corners=True)\n        s1_input = torch.cat([s2F, s1, pconv_1], axis=1)\n        cls_1, pro_1, _ = self.branch_cls_level_1(s1_input) \n        cls_1 = cls_1 + cls_1 * cls_prob_mask_1\n        return mask, mask_binary, cls_4, cls_3, cls_2, cls_1"
  },
  {
    "path": "models/NLCDetection_loc.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.seg_hrnet_config import get_cfg_defaults\nimport time\n\ndef weights_init(init_type='gaussian'):\n    def init_fun(m):\n        classname = m.__class__.__name__\n        if (classname.find('Conv') == 0 or classname.find(\n                'Linear') == 0) and hasattr(m, 'weight'):\n            if init_type == 'gaussian':\n                nn.init.normal_(m.weight, 0.0, 0.02)\n            elif init_type == 'xavier':\n                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))\n            elif init_type == 'kaiming':\n                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))\n            elif init_type == 'default':\n                pass\n            else:\n                assert 0, \"Unsupported initialization: {}\".format(init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                nn.init.constant_(m.bias, 0.0)\n    return init_fun\n\nclass PartialConv(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, dilation=1, groups=1, bias=True):\n        super().__init__()\n        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,\n                                    stride, padding, dilation, groups, bias)\n        self.mask_conv  = nn.Conv2d(in_channels, out_channels, kernel_size,\n                                    stride, padding, dilation, groups, False)\n        self.input_conv.apply(weights_init('kaiming'))\n        
torch.nn.init.constant_(self.mask_conv.weight, 1.0)\n        # mask is not updated\n        for param in self.mask_conv.parameters():\n            param.requires_grad = False\n\n    def forward(self, input, mask):\n        # http://masc.cs.gmu.edu/wiki/partialconv\n        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)\n        # W^T* (M .* X) / sum(M) + b = [C(M .* X) – C(0)] / D(M) + C(0)\n\n        ## GX: masking the input outside function.\n        output = self.input_conv(input)\n        if self.input_conv.bias is not None:\n            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output)\n        else:\n            output_bias = torch.zeros_like(output)        \n\n        with torch.no_grad():\n            output_mask = self.mask_conv(mask)\n\n        no_update_holes = output_mask == 0\n\n        ## in output_mask, fills the 0-value-position with 1.0\n        ## without this step, math error occurs.\n        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)\n        output_pre = (output - output_bias) / mask_sum + output_bias\n        output = output_pre.masked_fill_(no_update_holes, 0.0)\n\n        new_mask = torch.ones_like(output)\n        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)\n        \n        return output, new_mask\n\nclass NonLocalMask(nn.Module):\n    def __init__(self, in_channels, reduce_scale):\n        super(NonLocalMask, self).__init__()\n\n        self.r = reduce_scale\n\n        # input channel number\n        self.ic = in_channels * self.r * self.r\n\n        # middle channel number\n        self.mc = self.ic\n\n        self.g = nn.Conv2d(in_channels=self.ic, out_channels=self.ic,\n                           kernel_size=1, stride=1, padding=0)\n\n        self.theta = nn.Conv2d(in_channels=self.ic, out_channels=self.mc,\n                               kernel_size=1, stride=1, padding=0)\n        self.phi = nn.Conv2d(in_channels=self.ic, out_channels=self.mc,\n                             
kernel_size=1, stride=1, padding=0)\n        self.W_s = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.gamma_s = nn.Parameter(torch.ones(1))\n        self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=18,\n                                kernel_size=3, stride=1, padding=1)\n        self.relu = nn.ReLU()\n        self.conv_2 = nn.Conv2d(in_channels=18, out_channels=1, \n                                kernel_size=3, stride=1, padding=1)\n\n        ## Pconv\n        self.Pconv_1 = PartialConv(3, 3, kernel_size=3, stride=2)\n        self.Pconv_2 = PartialConv(3, 3, kernel_size=3, stride=2)\n        self.Pconv_3 = PartialConv(3, 1, kernel_size=3, stride=2)\n\n    def forward(self, x, img):\n        b, c, h, w = x.shape\n\n        x1 = x.reshape(b, self.ic, h // self.r, w // self.r)\n\n        # g x\n        g_x = self.g(x1).view(b, self.ic, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        # theta\n        theta_x = self.theta(x1).view(b, self.mc, -1)\n        theta_x_s = theta_x.permute(0, 2, 1)\n\n        # phi x\n        phi_x = self.phi(x1).view(b, self.mc, -1)\n        phi_x_s = phi_x\n\n        # non-local attention\n        f_s = torch.matmul(theta_x_s, phi_x_s)\n        f_s_div = F.softmax(f_s, dim=-1)\n\n        # get y_s\n        y_s = torch.matmul(f_s_div, g_x)\n        y_s = y_s.permute(0, 2, 1).contiguous()\n        y_s = y_s.view(b, c, h, w)\n\n        # GX: (256,256,18), output mask for the deep metric loss.\n        mask_feat = x + self.gamma_s * self.W_s(y_s)\n\n        # get 1-dimensional mask_tmp\n        # mask_binary = self.getmask(mask_feat)\n        mask_feat = self.conv_1(mask_feat)\n        mask_binary = mask_feat\n        mask_binary = self.relu(mask_binary)\n        # print(\"mask_feat: \", mask_feat.size())  # torch.Size([4, 18, 256, 256])\n        mask_binary = self.conv_2(mask_binary)\n        # print(\"mask_binary: \", 
mask_binary.size())  # torch.Size([4, 1, 256, 256])\n        mask_binary = torch.sigmoid(mask_binary)\n        mask_tmp = mask_binary.repeat(1, 3, 1, 1)\n        mask_img = img * mask_tmp # mask_img is the overlaid image.\n\n        ## conv output\n        x, new_mask = self.Pconv_1(mask_img, mask_tmp)\n        x, new_mask = self.Pconv_2(x, new_mask)\n        x, _        = self.Pconv_3(x, new_mask)\n        mask_binary = mask_binary.squeeze(dim=1)\n        return x, torch.sigmoid(mask_feat), mask_binary\n\nclass Flatten(nn.Module):\n    def __init__(self):\n        super(Flatten, self).__init__()\n        \n    def forward(self, x):       \n        return x.view(x.size(0), -1)\n\nclass Classifer(nn.Module):\n    def __init__(self, in_channels, output_channels):\n        super(Classifer, self).__init__()\n        self.pool = nn.Sequential(\n                                  # nn.AdaptiveAvgPool2d((1,1)),\n                                  nn.AdaptiveAvgPool2d(1),\n                                  Flatten()\n                                )\n        self.fc = nn.Linear(in_channels, output_channels, bias=True)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        feat = self.pool(x)\n        feat = self.relu(feat)\n        cls_res = self.fc(feat)\n        return cls_res\n\nclass BranchCLS(nn.Module):\n    def __init__(self, in_channels, output_channels):\n        super(BranchCLS, self).__init__()\n        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),\n                                  Flatten()\n                                )\n        self.fc = nn.Linear(18, output_channels, bias=True)\n        self.bn = nn.BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n        self.branch_cls = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=32, \n                                                  padding=1, kernel_size=3, stride=1),\n                                        
nn.ReLU(inplace=True),\n                                        nn.Conv2d(in_channels=32, out_channels=18,\n                                                  padding=1, kernel_size=3, stride=1),\n                                        nn.ReLU(inplace=True), \n                                        )\n        self.leakyrelu = nn.LeakyReLU(0.2)\n\n    def forward(self, x):\n        feat = self.branch_cls(x)\n        x = self.pool(feat)\n        x = self.bn(x)\n        cls_res = self.fc(x)\n        cls_pro = self.leakyrelu(cls_res)\n        zero_vec = -9e15*torch.ones_like(cls_pro)\n        cls_pro  = torch.where(cls_pro > 0, cls_pro, zero_vec)\n        return cls_res, cls_pro, feat\n\nclass FPN_loc(nn.Module):\n    '''self-implementation Feature Pyramid Networks '''\n    def __init__(self, args, clip_dim=64, multi_feat=None):\n        super(FPN_loc, self).__init__()\n        ## obtain the dimensions. \n        feat1_num, feat2_num, feat3_num, feat4_num = multi_feat\n\n        self.smooth_s4 = nn.Sequential(\n                                    nn.Conv2d(feat4_num, clip_dim, kernel_size=(1, 1), stride=(1, 1)),\n                                    nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n                                    )\n        self.smooth_s3 = nn.Sequential(\n                                    nn.Conv2d(feat3_num, clip_dim, kernel_size=(1, 1), stride=(1, 1)),\n                                    nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n                                    )\n        self.smooth_s2 = nn.Sequential(\n                                    nn.Conv2d(feat2_num, clip_dim, kernel_size=(1, 1), stride=(1, 1)),\n                                    nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n                                    )\n        self.smooth_s1 = nn.Sequential(\n                                    nn.Conv2d(feat1_num, clip_dim, 
kernel_size=(1, 1), stride=(1, 1)),\n                                    nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n                                    )\n\n        ## new branch.\n        self.fpn1 = nn.Sequential(\n            nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),\n            nn.BatchNorm2d(clip_dim),\n            nn.ReLU(),\n            # nn.Upsample(scale_factor=2)\n        )\n\n        self.fpn2 = nn.Sequential(\n            nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),\n            nn.BatchNorm2d(clip_dim),\n            nn.ReLU(),\n            nn.Upsample(scale_factor=2)\n        )\n\n        self.fpn3 = nn.Sequential(\n            nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),\n            nn.BatchNorm2d(clip_dim),\n            nn.ReLU(),\n            nn.Upsample(scale_factor=2),\n        )\n\n        self.fpn4 = nn.Sequential(\n            nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),\n            nn.BatchNorm2d(clip_dim),\n            nn.ReLU(),\n            nn.Upsample(scale_factor=2),\n        )\n\n        smooth_ops = [self.smooth_s4, self.smooth_s3, self.smooth_s2, self.smooth_s1]\n        fpn_ops = [self.fpn4, self.fpn3, self.fpn2, self.fpn1]\n\nclass NLCDetection(nn.Module):\n    def __init__(self):\n        super(NLCDetection, self).__init__()\n        self.crop_size = (256, 256)\n        self.split_tensor_1 = torch.tensor([1, 3]).cuda()\n        self.split_tensor_2 = torch.tensor([1, 2, 1, 3]).cuda()\n        self.softmax_m = nn.Softmax(dim=1)\n        FENet_cfg = get_cfg_defaults()\n        feat1_num, feat2_num, feat3_num, feat4_num = FENet_cfg['STAGE4']['NUM_CHANNELS']\n\n        ## mask generation branch.\n        feat_dim = 64 # large clip_dim will ruin the space of Multi-branch-feature-extractor\n        
self.getmask = NonLocalMask(feat_dim, 4)\n        self.FPN_LOC = FPN_loc(feat_dim, multi_feat=FENet_cfg['STAGE4']['NUM_CHANNELS'])\n\n        ## classification branch.\n        self.branch_cls_level_1 = BranchCLS(317, 14)   # 252 + 64 = 316\n        self.branch_cls_level_2 = BranchCLS(252, 7)    # 144+72+36 = 252\n        self.branch_cls_level_3 = BranchCLS(216, 5)    # 144+72 = 216\n        self.branch_cls_level_4 = BranchCLS(144, 3)    # 144\n\n    def feature_resize(self, feat):\n        '''first obtain the mask via the progressive scheme.'''\n        s1, s2, s3, s4 = feat\n        s1 = F.interpolate(s1, size=self.crop_size, mode='bilinear', align_corners=True)\n        s2 = F.interpolate(s2, size=[i // 2 for i in self.crop_size], mode='bilinear', align_corners=True)\n        s3 = F.interpolate(s3, size=[i // 4 for i in self.crop_size], mode='bilinear', align_corners=True)\n        s4 = F.interpolate(s4, size=[i // 8 for i in self.crop_size], mode='bilinear', align_corners=True)\n        return s1, s2, s3, s4\n\n    def forward(self, feat, img):\n\n        s1, s2, s3, s4 = self.feature_resize(feat)\n        img = F.interpolate(img, size=self.crop_size, \n                            mode='bilinear', align_corners=True)\n\n        feat_4 = self.FPN_LOC.smooth_s4(s4)\n        feat_4 = self.FPN_LOC.fpn4(feat_4)   \n        feat_3 = self.FPN_LOC.smooth_s3(s3)\n        feat_3 = self.FPN_LOC.fpn3(feat_3+feat_4)   \n        feat_2 = self.FPN_LOC.smooth_s2(s2)\n        feat_2 = self.FPN_LOC.fpn2(feat_2+feat_3)   \n        feat_1 = self.FPN_LOC.smooth_s1(s1)\n        s1 = self.FPN_LOC.fpn1(feat_1+feat_2)   \n        pconv_feat, mask, mask_binary = self.getmask(s1, img)\n        pconv_feat = pconv_feat.clone().detach()\n\n        pconv_1 = F.interpolate(pconv_feat, size=s1.size()[2:], mode='bilinear', align_corners=True)\n\n        ## forth branch.\n        cls_4, pro_4, _ = self.branch_cls_level_4(s4)\n        cls_prob_4      = self.softmax_m(pro_4)\n        cls_prob_40 = 
torch.unsqueeze(cls_prob_4[:,0],1)\n        cls_prob_41 = torch.unsqueeze(cls_prob_4[:,1],1)\n        cls_prob_42 = torch.unsqueeze(cls_prob_4[:,2],1)\n        cls_prob_mask_3 = torch.cat([cls_prob_40, cls_prob_41, cls_prob_41, cls_prob_42, cls_prob_42],axis=1)\n\n        ## third branch\n        s4F = F.interpolate(s4, size=s3.size()[2:], mode='bilinear', align_corners=True)\n        s3_input = torch.cat([s4F, s3], axis=1)\n        cls_3, pro_3, _ = self.branch_cls_level_3(s3_input)\n        cls_prob_3      = self.softmax_m(pro_3)\n        cls_3 = cls_3 + cls_3 * cls_prob_mask_3\n        cls_prob_30 = torch.unsqueeze(cls_prob_3[:,0],1)\n        cls_prob_31 = torch.unsqueeze(cls_prob_3[:,1],1)\n        cls_prob_32 = torch.unsqueeze(cls_prob_3[:,2],1)\n        cls_prob_33 = torch.unsqueeze(cls_prob_3[:,3],1)\n        cls_prob_34 = torch.unsqueeze(cls_prob_3[:,4],1)\n        cls_prob_mask_2 = torch.cat([cls_prob_30, cls_prob_31, cls_prob_31, \n                                     cls_prob_32, cls_prob_32,\n                                     cls_prob_33, cls_prob_34],axis=1)\n\n        ## second branch\n        s3F = F.interpolate(s3_input, size=s2.size()[2:], mode='bilinear', align_corners=True)\n        s2_input = torch.cat([s3F, s2], axis=1)\n        cls_2, pro_2, _ = self.branch_cls_level_2(s2_input) \n        cls_prob_2      = self.softmax_m(pro_2)\n        cls_2 = cls_2 + cls_2 * cls_prob_mask_2\n        cls_prob_20 = torch.unsqueeze(cls_prob_2[:,0],1)\n        cls_prob_21 = torch.unsqueeze(cls_prob_2[:,1],1)\n        cls_prob_22 = torch.unsqueeze(cls_prob_2[:,2],1)\n        cls_prob_23 = torch.unsqueeze(cls_prob_2[:,3],1)\n        cls_prob_24 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_25 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_26 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_mask_1 = torch.cat([cls_prob_20, \n                                     cls_prob_21, cls_prob_21, cls_prob_22, cls_prob_22,    # 4 diffusion\n          
                           cls_prob_23, cls_prob_23, cls_prob_24, cls_prob_24,    # 4 gan\n                                     cls_prob_25, cls_prob_25,                              # faceshifter+stgan\n                                     cls_prob_26, cls_prob_26, cls_prob_26], axis=1)        # 3 editing\n\n        s2F = F.interpolate(s2_input, size=s1.size()[2:], mode='bilinear', align_corners=True)\n        s1_input = torch.cat([s2F, s1, pconv_1], axis=1)\n        cls_1, pro_1, _ = self.branch_cls_level_1(s1_input) \n        cls_1 = cls_1 + cls_1 * cls_prob_mask_1\n\n        mask = mask.squeeze(dim=1)\n        return mask, mask_binary, cls_4, cls_3, cls_2, cls_1\n"
  },
  {
    "path": "models/NLCDetection_pconv.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.seg_hrnet_config import get_cfg_defaults\nimport time\n\ndef weights_init(init_type='gaussian'):\n    def init_fun(m):\n        classname = m.__class__.__name__\n        if (classname.find('Conv') == 0 or classname.find(\n                'Linear') == 0) and hasattr(m, 'weight'):\n            if init_type == 'gaussian':\n                nn.init.normal_(m.weight, 0.0, 0.02)\n            elif init_type == 'xavier':\n                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))\n            elif init_type == 'kaiming':\n                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))\n            elif init_type == 'default':\n                pass\n            else:\n                assert 0, \"Unsupported initialization: {}\".format(init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                nn.init.constant_(m.bias, 0.0)\n    return init_fun\n\nclass PartialConv(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, dilation=1, groups=1, bias=True):\n        super().__init__()\n        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,\n                                    stride, padding, dilation, groups, bias)\n        self.mask_conv  = nn.Conv2d(in_channels, out_channels, kernel_size,\n                                    stride, padding, dilation, groups, False)\n        self.input_conv.apply(weights_init('kaiming'))\n        
torch.nn.init.constant_(self.mask_conv.weight, 1.0)\n        # mask is not updated\n        for param in self.mask_conv.parameters():\n            param.requires_grad = False\n\n    def forward(self, input, mask):\n        # http://masc.cs.gmu.edu/wiki/partialconv\n        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)\n        # W^T* (M .* X) / sum(M) + b = [C(M .* X) – C(0)] / D(M) + C(0)\n\n        ## GX: masking the input outside function.\n        output = self.input_conv(input)\n        if self.input_conv.bias is not None:\n            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output)\n        else:\n            output_bias = torch.zeros_like(output)        \n\n        with torch.no_grad():\n            output_mask = self.mask_conv(mask)\n\n        no_update_holes = output_mask == 0\n\n        ## in output_mask, fills the 0-value-position with 1.0\n        ## without this step, math error occurs.\n        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)\n        output_pre = (output - output_bias) / mask_sum + output_bias\n        output = output_pre.masked_fill_(no_update_holes, 0.0)\n\n        new_mask = torch.ones_like(output)\n        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)\n        \n        return output, new_mask\n\nclass NonLocalMask(nn.Module):\n    def __init__(self, in_channels, reduce_scale):\n        super(NonLocalMask, self).__init__()\n\n        self.r = reduce_scale\n\n        # input channel number\n        self.ic = in_channels * self.r * self.r\n\n        # middle channel number\n        self.mc = self.ic\n\n        self.g = nn.Conv2d(in_channels=self.ic, out_channels=self.ic,\n                           kernel_size=1, stride=1, padding=0)\n\n        self.theta = nn.Conv2d(in_channels=self.ic, out_channels=self.mc,\n                               kernel_size=1, stride=1, padding=0)\n        self.phi = nn.Conv2d(in_channels=self.ic, out_channels=self.mc,\n                             
kernel_size=1, stride=1, padding=0)\n        self.W_s = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.gamma_s = nn.Parameter(torch.ones(1))\n        self.getmask = nn.Sequential(\n                                    nn.Conv2d(in_channels=in_channels, out_channels=16, \n                                              kernel_size=3, stride=1, padding=1),\n                                    nn.ReLU(),\n                                    nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1)\n                                    )\n\n        ## Pconv\n        self.Pconv_1 = PartialConv(3, 3, kernel_size=3, stride=2)\n        self.Pconv_2 = PartialConv(3, 3, kernel_size=3, stride=2)\n        self.Pconv_3 = PartialConv(3, 1, kernel_size=3, stride=2)\n\n    def forward(self, x, img):\n        b, c, h, w = x.shape\n\n        x1 = x.reshape(b, self.ic, h // self.r, w // self.r)\n\n        # g x\n        g_x = self.g(x1).view(b, self.ic, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        # theta\n        theta_x = self.theta(x1).view(b, self.mc, -1)\n        theta_x_s = theta_x.permute(0, 2, 1)\n\n        # phi x\n        phi_x = self.phi(x1).view(b, self.mc, -1)\n        phi_x_s = phi_x\n\n        # non-local attention\n        f_s = torch.matmul(theta_x_s, phi_x_s)\n        f_s_div = F.softmax(f_s, dim=-1)\n\n        # get y_s\n        y_s = torch.matmul(f_s_div, g_x)\n        y_s = y_s.permute(0, 2, 1).contiguous()\n        y_s = y_s.view(b, c, h, w)\n\n        # GX: (256,256,18), output mask for the deep metric loss.\n        mask_feat = x + self.gamma_s * self.W_s(y_s)\n\n        # get 1-dimensional mask_tmp\n        mask_binary = torch.sigmoid(self.getmask(mask_feat))\n        mask_tmp = mask_binary.repeat(1, 3, 1, 1)\n        mask_img = img * mask_tmp # mask_img is the overlaid image.\n\n        ## conv output\n        x, new_mask = 
self.Pconv_1(mask_img, mask_tmp)\n        x, new_mask = self.Pconv_2(x, new_mask)\n        x, _        = self.Pconv_3(x, new_mask)\n        mask_binary = mask_binary.squeeze(dim=1)\n        return x, mask_feat, mask_binary\n\nclass Flatten(nn.Module):\n    def __init__(self):\n        super(Flatten, self).__init__()\n        \n    def forward(self, x):       \n        return x.view(x.size(0), -1)\n\nclass Classifer(nn.Module):\n    def __init__(self, in_channels, output_channels):\n        super(Classifer, self).__init__()\n        self.pool = nn.Sequential(\n                                  # nn.AdaptiveAvgPool2d((1,1)),\n                                  nn.AdaptiveAvgPool2d(1),\n                                  Flatten()\n                                )\n        self.fc = nn.Linear(in_channels, output_channels, bias=True)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        feat = self.pool(x)\n        feat = self.relu(feat)\n        cls_res = self.fc(feat)\n        return cls_res\n\nclass BranchCLS(nn.Module):\n    def __init__(self, in_channels, output_channels):\n        super(BranchCLS, self).__init__()\n        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),\n                                  Flatten()\n                                )\n        self.fc = nn.Linear(18, output_channels, bias=True)\n        self.bn = nn.BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n        self.branch_cls = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=32, \n                                                  padding=1, kernel_size=3, stride=1),\n                                        nn.ReLU(inplace=True),\n                                        nn.Conv2d(in_channels=32, out_channels=18,\n                                                  padding=1, kernel_size=3, stride=1),\n                                        nn.ReLU(inplace=True), \n                                        )\n      
  self.leakyrelu = nn.LeakyReLU(0.2)\n\n    def forward(self, x):\n        feat = self.branch_cls(x)\n        x = self.pool(feat)\n        x = self.bn(x)\n        cls_res = self.fc(x)\n        cls_pro = self.leakyrelu(cls_res)\n        zero_vec = -9e15*torch.ones_like(cls_pro)\n        cls_pro  = torch.where(cls_pro > 0, cls_pro, zero_vec)\n        return cls_res, cls_pro, feat\n\nclass NLCDetection(nn.Module):\n    def __init__(self, args):\n        super(NLCDetection, self).__init__()\n        self.crop_size = args.crop_size\n        self.split_tensor_1 = torch.tensor([1, 3]).cuda()\n        self.split_tensor_2 = torch.tensor([1, 2, 1, 3]).cuda()\n        self.softmax_m = nn.Softmax(dim=1)\n        FENet_cfg = get_cfg_defaults()\n        feat1_num, feat2_num, feat3_num, feat4_num = FENet_cfg['STAGE4']['NUM_CHANNELS']\n\n        ## mask generation branch.\n        self.getmask = NonLocalMask(feat1_num, 4)\n\n        ## classification branch.\n        self.branch_cls_level_1 = BranchCLS(271, 14)   # 252 + 18 = 270\n        self.branch_cls_level_2 = BranchCLS(252, 7)    # 144+72+36 = 252\n        self.branch_cls_level_3 = BranchCLS(216, 5)    # 144+72 = 216\n        self.branch_cls_level_4 = BranchCLS(144, 3)    # 144\n\n    def forward(self, feat, img):\n        s1, s2, s3, s4 = feat\n\n        # mask_binary is intermediate result, to ignore.\n        pconv_feat, mask, mask_binary = self.getmask(s1, img)\n        pconv_feat = pconv_feat.clone().detach()\n\n        pconv_1 = F.interpolate(pconv_feat, size=s1.size()[2:], mode='bilinear', align_corners=True)\n\n        ## forth branch.\n        cls_4, pro_4, _ = self.branch_cls_level_4(s4)\n        cls_prob_4      = self.softmax_m(pro_4)\n        cls_prob_40 = torch.unsqueeze(cls_prob_4[:,0],1)\n        cls_prob_41 = torch.unsqueeze(cls_prob_4[:,1],1)\n        cls_prob_42 = torch.unsqueeze(cls_prob_4[:,2],1)\n        cls_prob_mask_3 = torch.cat([cls_prob_40, cls_prob_41, cls_prob_41, cls_prob_42, 
cls_prob_42],axis=1)\n\n        ## third branch\n        s4F = F.interpolate(s4, size=s3.size()[2:], mode='bilinear', align_corners=True)\n        s3_input = torch.cat([s4F, s3], axis=1)\n        cls_3, pro_3, _ = self.branch_cls_level_3(s3_input)\n        cls_prob_3      = self.softmax_m(pro_3)\n        cls_3 = cls_3 + cls_3 * cls_prob_mask_3\n        cls_prob_30 = torch.unsqueeze(cls_prob_3[:,0],1)\n        cls_prob_31 = torch.unsqueeze(cls_prob_3[:,1],1)\n        cls_prob_32 = torch.unsqueeze(cls_prob_3[:,2],1)\n        cls_prob_33 = torch.unsqueeze(cls_prob_3[:,3],1)\n        cls_prob_34 = torch.unsqueeze(cls_prob_3[:,4],1)\n        cls_prob_mask_2 = torch.cat([cls_prob_30, cls_prob_31, cls_prob_31, \n                                     cls_prob_32, cls_prob_32,\n                                     cls_prob_33, cls_prob_34],axis=1)\n\n        ## second branch\n        s3F = F.interpolate(s3_input, size=s2.size()[2:], mode='bilinear', align_corners=True)\n        s2_input = torch.cat([s3F, s2], axis=1)\n        cls_2, pro_2, _ = self.branch_cls_level_2(s2_input) \n        cls_prob_2      = self.softmax_m(pro_2)\n        cls_2 = cls_2 + cls_2 * cls_prob_mask_2\n        cls_prob_20 = torch.unsqueeze(cls_prob_2[:,0],1)\n        cls_prob_21 = torch.unsqueeze(cls_prob_2[:,1],1)\n        cls_prob_22 = torch.unsqueeze(cls_prob_2[:,2],1)\n        cls_prob_23 = torch.unsqueeze(cls_prob_2[:,3],1)\n        cls_prob_24 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_25 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_26 = torch.unsqueeze(cls_prob_2[:,4],1)\n        cls_prob_mask_1 = torch.cat([cls_prob_20, \n                                     cls_prob_21, cls_prob_21, cls_prob_22, cls_prob_22,    # 4 diffusion\n                                     cls_prob_23, cls_prob_23, cls_prob_24, cls_prob_24,    # 4 gan\n                                     cls_prob_25, cls_prob_25,                              # faceshifter+stgan\n                                     
cls_prob_26, cls_prob_26, cls_prob_26], axis=1)        # 3 editing\n\n        s2F = F.interpolate(s2_input, size=s1.size()[2:], mode='bilinear', align_corners=True)\n        s1_input = torch.cat([s2F, s1, pconv_1], axis=1)\n        cls_1, pro_1, _ = self.branch_cls_level_1(s1_input) \n        cls_1 = cls_1 + cls_1 * cls_prob_mask_1\n        return mask, mask_binary, cls_4, cls_3, cls_2, cls_1"
  },
  {
    "path": "models/seg_hrnet.py",
    "content": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed under the MIT License.\n# The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn)\n# ------------------------------------------------------------------------------\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nfrom .LaPlacianMs import LaPlacianMs\nfrom .NLCDetection_pconv import weights_init\n\nimport os\nimport logging\nimport functools\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch._utils\nimport torch.nn.functional as F\n\nBN_MOMENTUM = 0.01\nlogger = logging.getLogger(__name__)\n\n# noise generation\ndef srm_generation(image):\n    \"\"\"\n    :param image: N * C * H * W\n    :return: noises\n    \"\"\"\n\n    # srm kernel 1\n    srm1 = np.zeros([5, 5]).astype('float32')\n    srm1[1:-1, 1:-1] = np.array([[-1, 2, -1],\n                                 [2, -4, 2],\n                                 [-1, 2, -1]])\n    srm1 /= 4.\n    # srm kernel 2\n    srm2 = np.array([[-1, 2, -2, 2, -1],\n                     [2, -6, 8, -6, 2],\n                     [-2, 8, -12, 8, -2],\n                     [2, -6, 8, -6, 2],\n                     [-1, 2, -2, 2, -1]]).astype('float32')\n    srm2 /= 12.\n    # srm kernel 3\n    srm3 = np.zeros([5, 5]).astype('float32')\n    srm3[2, 1:-1] = np.array([1, -2, 1])\n    srm3 /= 2.\n\n    srm = np.stack([srm1, srm2, srm3], axis=0)\n\n    W_srm = np.zeros([3, 3, 5, 5]).astype('float32')\n\n    for i in range(3):\n        W_srm[i, 0, :, :] = srm[i, :, :]\n        W_srm[i, 1, :, :] = srm[i, :, :]\n        W_srm[i, 2, :, :] = srm[i, :, :]\n\n    W_srm = torch.from_numpy(W_srm).to(image.get_device())\n\n    srm_noise = F.conv2d(image, W_srm, padding=2)\n\n    return srm_noise\n\n# bayar constrained layer\nclass BayarConstraint(object):\n    def __init__(self):\n        pass\n\n    def __call__(self, module):\n     
   if hasattr(module, 'weight'):\n            weight = module.weight.data      # oc, ic, h, w\n\n            h, w = weight.size()[2:]\n            mask = torch.zeros_like(weight)\n            mask[:, :, h//2, w//2] = 1\n\n            weight *= (1 - mask)\n            rest_sum = torch.sum(weight, dim=(2, 3), keepdim=True)\n            weight /= (rest_sum + 1e-7)\n            weight -= mask\n            module.weight.data = weight\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\nclass CatDepth(nn.Module):\n    def __init__(self):\n        super(CatDepth, self).__init__()\n\n    def forward(self, x, y):\n        return torch.cat([x,y],dim=1)\n\n'''GX: basicblock contains two conv3x3 and two batch norm'''\n'''GX: at last, it has a residual connection'''\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out = out + residual\n        out = self.relu(out)\n\n        return out\n\n'''GX: 3 conv + 3 bn then a residual.'''\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        
super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,\n                               bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion,\n                                  momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out = out + residual\n        out = self.relu(out)\n\n        return out\n\n'''GX: the basic component in the network.'''\nclass HighResolutionModule(nn.Module):\n    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,\n                 num_channels, fuse_method, multi_scale_output=True):\n        super(HighResolutionModule, self).__init__()\n        self._check_branches(\n            num_branches, blocks, num_blocks, num_inchannels, num_channels)\n\n        self.num_inchannels = num_inchannels\n        self.fuse_method = fuse_method\n        self.num_branches = num_branches\n\n        self.multi_scale_output = multi_scale_output\n\n        self.branches = self._make_branches(\n            num_branches, blocks, num_blocks, num_channels)\n        self.fuse_layers = self._make_fuse_layers()\n        self.relu = nn.ReLU(inplace=False)\n\n    def 
_check_branches(self, num_branches, blocks, num_blocks,\n                        num_inchannels, num_channels):\n        if num_branches != len(num_blocks):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(\n                num_branches, len(num_blocks))\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_channels):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(\n                num_branches, len(num_channels))\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_inchannels):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(\n                num_branches, len(num_inchannels))\n            raise ValueError(error_msg)\n\n    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,\n                         stride=1):\n        downsample = None\n        if stride != 1 or \\\n                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.num_inchannels[branch_index],\n                          num_channels[branch_index] * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,\n                               momentum=BN_MOMENTUM),\n            )\n\n        layers = []\n        layers.append(block(self.num_inchannels[branch_index],\n                            num_channels[branch_index], stride, downsample))\n        self.num_inchannels[branch_index] = \\\n            num_channels[branch_index] * block.expansion\n        for i in range(1, num_blocks[branch_index]):\n            layers.append(block(self.num_inchannels[branch_index],\n                                num_channels[branch_index]))\n\n        return nn.Sequential(*layers)\n\n    def _make_branches(self, num_branches, block, num_blocks, num_channels):\n        
branches = []\n\n        for i in range(num_branches):\n            branches.append(\n                self._make_one_branch(i, block, num_blocks, num_channels))\n\n        return nn.ModuleList(branches)\n\n    ## GX: fuse layer converts feature maps at different resolution branches\n    ## GX: into the feature map of the new branches' feature map.\n    ## GX: https://zhuanlan.zhihu.com/p/335333233\n    def _make_fuse_layers(self):\n        if self.num_branches == 1:\n            return None\n\n        num_branches = self.num_branches\n        num_inchannels = self.num_inchannels\n        fuse_layers = []\n        for i in range(num_branches if self.multi_scale_output else 1):\n            fuse_layer = []\n            for j in range(num_branches):\n                if j > i:\n                    fuse_layer.append(nn.Sequential(\n                        nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),\n                        nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))\n                elif j == i:\n                    fuse_layer.append(None)\n                else:\n                    conv3x3s = []\n                    for k in range(i - j):\n                        if k == i - j - 1:\n                            num_outchannels_conv3x3 = num_inchannels[i]\n                            conv3x3s.append(nn.Sequential(\n                                nn.Conv2d(num_inchannels[j],\n                                          num_outchannels_conv3x3,\n                                          3, 2, 1, bias=False),\n                                nn.BatchNorm2d(num_outchannels_conv3x3,\n                                               momentum=BN_MOMENTUM)))\n                        else:\n                            num_outchannels_conv3x3 = num_inchannels[j]\n                            conv3x3s.append(nn.Sequential(\n                                nn.Conv2d(num_inchannels[j],\n                                          
num_outchannels_conv3x3,\n                                          3, 2, 1, bias=False),\n                                nn.BatchNorm2d(num_outchannels_conv3x3,\n                                               momentum=BN_MOMENTUM),\n                                nn.ReLU(inplace=False)))\n                    fuse_layer.append(nn.Sequential(*conv3x3s))\n            fuse_layers.append(nn.ModuleList(fuse_layer))\n\n        return nn.ModuleList(fuse_layers)\n\n    def get_num_inchannels(self):\n        return self.num_inchannels\n\n    def forward(self, x):\n        if self.num_branches == 1:\n            return [self.branches[0](x[0])]\n\n        for i in range(self.num_branches):\n            x[i] = self.branches[i](x[i])\n\n        x_fuse = []\n        for i in range(len(self.fuse_layers)):\n            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])\n            for j in range(1, self.num_branches):\n                if i == j:\n                    y = y + x[j]\n                elif j > i:\n                    width_output = x[i].shape[-1]\n                    height_output = x[i].shape[-2]\n                    y = y + F.interpolate(\n                        self.fuse_layers[i][j](x[j]),\n                        size=[height_output, width_output],\n                        mode='bilinear', align_corners=True)\n                else:\n                    y = y + self.fuse_layers[i][j](x[j])\n            x_fuse.append(self.relu(y))\n\n        return x_fuse\n\n\nblocks_dict = {\n    'BASIC': BasicBlock,\n    'BOTTLENECK': Bottleneck\n}\n\n## GX: the HighResolutionNet has 4 stages. 
\n## GX: each stage has one module which is HighResolutionModule.\n## GX: HighResolutionModule has 1,2,3,4 branches.\n## GX: each stage has a transitional layers in between.\nclass HighResolutionNet(nn.Module):\n\n    def __init__(self, config, **kwargs):\n        super(HighResolutionNet, self).__init__()\n\n        # noise conv\n        # self.im_conv = nn.Conv2d(3, 10, kernel_size=3, stride=1, padding=1, bias=False)\n        # self.bayar_conv = nn.Conv2d(3, 3, kernel_size=5, stride=1, padding=2, bias=False)\n        # self.constraints = BayarConstraint()\n\n        # stem net\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n\n        # frequency branch\n        self.conv1fre = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn1fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.conv2fre = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn2fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.laplacian = LaPlacianMs(in_c=64,gauss_ker_size=3,scale=[2,4,8])\n\n        # concat\n        self.concat_depth = CatDepth()\n        self.conv_1x1_merge = nn.Sequential(nn.Conv2d(128, 64,\n                                                  kernel_size=1, stride=1,\n                                                  bias=False,groups=2),\n                                        nn.BatchNorm2d(64),\n                                        nn.ReLU(inplace=True),\n                                        nn.Dropout(p=0.2)\n                                       )\n        # self.module_initializer = module_initializer()\n        # self.conv_1x1_merge = self.module_initializer(self.conv_1x1_merge)\n        
self.conv_1x1_merge.apply(weights_init('kaiming'))\n\n        self.stage1_cfg = config['STAGE1']\n        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]\n        block = blocks_dict[self.stage1_cfg['BLOCK']]\n        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]\n        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)\n        stage1_out_channel = block.expansion * num_channels\n\n        self.stage2_cfg = config['STAGE2']\n        num_channels = self.stage2_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage2_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition1 = self._make_transition_layer(\n            [stage1_out_channel], num_channels)\n        self.stage2, pre_stage_channels = self._make_stage(\n            self.stage2_cfg, num_channels)\n\n        self.stage3_cfg = config['STAGE3']\n        num_channels = self.stage3_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage3_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition2 = self._make_transition_layer(\n            pre_stage_channels, num_channels)\n        self.stage3, pre_stage_channels = self._make_stage(\n            self.stage3_cfg, num_channels)\n\n        self.stage4_cfg = config['STAGE4']\n        num_channels = self.stage4_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage4_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition3 = self._make_transition_layer(\n            pre_stage_channels, num_channels)\n        self.stage4, pre_stage_channels = self._make_stage(\n            self.stage4_cfg, num_channels, multi_scale_output=True)\n\n        last_inp_channels = np.int(np.sum(pre_stage_channels))\n\n    ## GX: one dimension matrix converts pre to pos.\n    ## GX: if channel 
numbers are equal, pass it directly.\n    ## GX: if channel numbers are different, using conv 3x3.\n    ## GX: https://zhuanlan.zhihu.com/p/335333233\n    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):\n        num_branches_cur = len(num_channels_cur_layer)\n        num_branches_pre = len(num_channels_pre_layer)\n\n        transition_layers = []\n        for i in range(num_branches_cur):\n            if i < num_branches_pre:\n                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:\n                    transition_layers.append(nn.Sequential(\n                        nn.Conv2d(num_channels_pre_layer[i],\n                                  num_channels_cur_layer[i],\n                                  3,\n                                  1,\n                                  1,\n                                  bias=False),\n                        nn.BatchNorm2d(\n                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),\n                        nn.ReLU(inplace=False)))\n                else:\n                    transition_layers.append(None)\n            else:\n                conv3x3s = []\n                for j in range(i + 1 - num_branches_pre):\n                    inchannels = num_channels_pre_layer[-1]\n                    outchannels = num_channels_cur_layer[i] \\\n                        if j == i - num_branches_pre else inchannels\n                    conv3x3s.append(nn.Sequential(\n                        nn.Conv2d(\n                            inchannels, outchannels, 3, 2, 1, bias=False),\n                        nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),\n                        nn.ReLU(inplace=False)))\n                transition_layers.append(nn.Sequential(*conv3x3s))\n\n        return nn.ModuleList(transition_layers)\n\n    ## GX: _make_layer creates a conv + bn\n    def _make_layer(self, block, inplanes, planes, blocks, stride=1):\n        downsample = None\n        
if stride != 1 or inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),\n            )\n\n        layers = []\n        layers.append(block(inplanes, planes, stride, downsample))\n        inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):\n        ## GX: num_modules are all 1 in this work.\n        ## GX: light-weight architectures: num_blocks are all 0.\n        ## GX: branch numbers are 2, 3, 4.\n        num_modules = layer_config['NUM_MODULES'] \n        num_branches = layer_config['NUM_BRANCHES']\n        num_blocks = layer_config['NUM_BLOCKS']\n        num_channels = layer_config['NUM_CHANNELS']\n        block = blocks_dict[layer_config['BLOCK']]\n        fuse_method = layer_config['FUSE_METHOD']\n\n        modules = []\n        for i in range(num_modules):\n            # multi_scale_output is only used last module\n            if not multi_scale_output and i == num_modules - 1:\n                reset_multi_scale_output = False\n            else:\n                reset_multi_scale_output = True\n            modules.append(\n                HighResolutionModule(num_branches, block, num_blocks,\n                                     num_inchannels, num_channels, fuse_method,\n                                     reset_multi_scale_output)\n            )\n            num_inchannels = modules[-1].get_num_inchannels()\n\n        return nn.Sequential(*modules), num_inchannels\n\n    def forward(self, x):\n        x_fre = self.conv1fre(x)\n        x_fre = self.bn1fre(x_fre)\n        x_fre = self.relu(x_fre)\n        x_fre = 
self.laplacian(x_fre)\n        x_fre = self.conv2fre(x_fre)\n        x_fre = self.bn2fre(x_fre)\n        x_fre = self.relu(x_fre)\n\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.conv2(x)\n        x = self.bn2(x)\n        x = self.relu(x)\n        x = self.concat_depth(x, x_fre)\n        x = self.conv_1x1_merge(x)\n        x = self.layer1(x)  \n\n        x_list = []\n        for i in range(self.stage2_cfg['NUM_BRANCHES']):\n            if self.transition1[i] is not None:\n                x_list.append(self.transition1[i](x))\n            else:\n                x_list.append(x)\n        y_list = self.stage2(x_list)\n        x_list = []\n        for i in range(self.stage3_cfg['NUM_BRANCHES']):\n            if self.transition2[i] is not None:\n                if i < self.stage2_cfg['NUM_BRANCHES']:\n                    x_list.append(self.transition2[i](y_list[i]))\n                else:\n                    x_list.append(self.transition2[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        y_list = self.stage3(x_list)\n        x_list = []\n        for i in range(self.stage4_cfg['NUM_BRANCHES']):\n            if self.transition3[i] is not None:\n                if i < self.stage3_cfg['NUM_BRANCHES']:\n                    x_list.append(self.transition3[i](y_list[i]))\n                else:\n                    x_list.append(self.transition3[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        x = self.stage4(x_list)\n        return x\n\n    def init_weights(self, pretrained='',):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.normal_(m.weight, std=0.001)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n        if os.path.isfile(pretrained):\n            pretrained_dict = torch.load(pretrained)    \n            
## GX: official pre-trained dict.\n            print('=> loading HRNet pretrained model {}'.format(pretrained))\n            model_dict = self.state_dict()  \n            model_pretrained_lst, model_nopretrained_lst = [], []\n            ## GX: model_dict is weights from the current architecture.\n            pretrained_dict_used = {}\n            ## GX: gather weights from pretrained_dict to model_dict.\n            nopretrained_dict = {k: v for k, v in model_dict.items()}\n\n            for k, v in model_dict.items():\n                pretrained_key = 'model.' + k\n                if pretrained_key not in pretrained_dict.keys():\n                    if 'stage2' in pretrained_key and 'fuse_layers' not in pretrained_key:\n                        if 'branches.2' in pretrained_key:\n                            pretrained_key = pretrained_key.replace('stage2.0.', 'stage3.0.')\n                        elif 'branches.3' in pretrained_key:\n                            pretrained_key = pretrained_key.replace('stage2.0.', 'stage4.0.')\n                    elif 'stage3' in pretrained_key and 'fuse_layers' not in pretrained_key:\n                        pretrained_key = pretrained_key.replace('stage3.0.', 'stage4.0.')\n                    elif 'fre' in pretrained_key:\n                        pretrained_key = pretrained_key.replace('fre', '')\n                if pretrained_key in pretrained_dict.keys():\n                    pretrained_dict_used[k] = pretrained_dict[pretrained_key]\n                    nopretrained_dict.pop(k)\n            print(\"no pretrain dict length is: \", len(nopretrained_dict))\n            print(\"pretrained dict length is: \", len(pretrained_dict))\n            model_dict.update(pretrained_dict_used)\n            self.load_state_dict(model_dict)\n\ndef get_seg_model(cfg, **kwargs):\n    model = HighResolutionNet(cfg, **kwargs)\n    model.init_weights(cfg.PRETRAINED)\n    return model"
  },
  {
    "path": "models/seg_hrnet_config.py",
    "content": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed under the MIT License.\n# The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn)\n# ------------------------------------------------------------------------------\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom yacs.config import CfgNode as CN\n\n# high_resoluton_net related params for segmentation\nHRNET = CN()\nHRNET.PRETRAINED_LAYERS = ['*']\nHRNET.STEM_INPLANES = 64\nHRNET.FINAL_CONV_KERNEL = 1\nHRNET.PRETRAINED = 'models/hrnet_w18_small_v2.pth'\n\nHRNET.STAGE1 = CN()\nHRNET.STAGE1.NUM_MODULES = 1\nHRNET.STAGE1.NUM_BRANCHES = 1\nHRNET.STAGE1.NUM_BLOCKS = [2]\nHRNET.STAGE1.NUM_CHANNELS = [64]\nHRNET.STAGE1.BLOCK = 'BOTTLENECK'\nHRNET.STAGE1.FUSE_METHOD = 'SUM'\n\nHRNET.STAGE2 = CN()\nHRNET.STAGE2.NUM_MODULES = 1\nHRNET.STAGE2.NUM_BRANCHES = 4\nHRNET.STAGE2.NUM_BLOCKS = [2, 2, 2, 2]\nHRNET.STAGE2.NUM_CHANNELS = [18, 36, 72, 144]\nHRNET.STAGE2.BLOCK = 'BASIC'\nHRNET.STAGE2.FUSE_METHOD = 'SUM'\n\nHRNET.STAGE3 = CN()\nHRNET.STAGE3.NUM_MODULES = 1\nHRNET.STAGE3.NUM_BRANCHES = 4\nHRNET.STAGE3.NUM_BLOCKS = [2, 2, 2, 2]\nHRNET.STAGE3.NUM_CHANNELS = [18, 36, 72, 144]\nHRNET.STAGE3.BLOCK = 'BASIC'\nHRNET.STAGE3.FUSE_METHOD = 'SUM'\n\nHRNET.STAGE4 = CN()\nHRNET.STAGE4.NUM_MODULES = 1\nHRNET.STAGE4.NUM_BRANCHES = 4\nHRNET.STAGE4.NUM_BLOCKS = [2, 2, 2, 2]\nHRNET.STAGE4.NUM_CHANNELS = [18, 36, 72, 144]\nHRNET.STAGE4.BLOCK = 'BASIC'\nHRNET.STAGE4.FUSE_METHOD = 'SUM'\n\n\ndef get_cfg_defaults():\n  \"\"\"Get a yacs CfgNode object with default values for my_project.\"\"\"\n  # Return a clone so that the defaults will not be altered\n  # This is for the \"local variable\" use pattern\n  return HRNET.clone()"
  },
  {
    "path": "utils/custom_loss.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\nimport os\nimport time\nimport torch\nimport torch.nn as nn\nfrom tqdm import tqdm\n\ndevice = torch.device('cuda:0')\ndevice_ids = [0]\n\nclass IsolatingLossFunction(torch.nn.Module):\n\tdef __init__(self, c, R, p=2, threshold_val=1.85):\n\t\tsuper().__init__()\n\t\tself.c = c.clone().detach() # Center of the hypershpere, c ∈ ℝ^d (d-dimensional real-valued vector)\n\t\tself.R = R.clone().detach() # Radius of the hypersphere, R ∈ ℝ^1 (Real-valued)\n\t\tself.p = p                  # norm value (p-norm), p ∈ ℝ^1 (Default 2)\n\t\tself.margin_natu = (0.15)*self.R    \n\t\tself.margin_mani = (2.5)*self.R\n\t\tself.threshold   = threshold_val*self.R\n\n\t\tprint('\\n')\n\t\tprint('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')\n\t\tprint('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')\n\t\tprint(f'The Radius manipul is {self.margin_natu}.')\n\t\tprint(f'The Radius expansn is {self.margin_mani}.')\n\t\tprint(f'The Radius threshold is {self.threshold}.')\n\t\tprint('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')\n\t\tprint('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')\n\t\tprint('\\n')\n\t\tself.pdist = torch.nn.PairwiseDistance(p=self.p) # Creating a Pairwise Distance object\n\t\tself.dis_curBatch = 0\n\t\tself.dis = 0\n\n\tdef forward(self, model_output, label, threshold_new=None, update_flag=None):\n\t\t'''output the distance mask and compute the loss.'''\n\t\tbs, feat_dim, w, h = model_output.size()\n\t\tmodel_output = model_output.permute(0,2,3,1)\n\t\tmodel_output = torch.reshape(model_output, (-1, feat_dim))\n\t\tdist = self.pdist(model_output, self.c)\n\t\tself.dist = dist\n\t\tpred_mask = torch.gt(self.dist, self.threshold).to(torch.float32)\n\t\tpred_mask = 
torch.reshape(pred_mask, (bs,w,h,1)).permute(0,3,1,2)\n\t\tself.dist = torch.reshape(self.dist, (bs,w,h,1)).permute(0,3,1,2)\n\t\tself.dis_curBatch = pred_mask.to(device).to(torch.float32)\n\n\t\tlabel = torch.reshape(label, (bs*w*h,1))\n\t\tlabel_sum = label.sum().item()\n\n\t\tlabel_nat  = torch.eq(label,0)\n\t\tlabel_mani = torch.eq(label,1)\n\t\tassert dist.size() == label_nat[:,0].size() \n\t\tassert dist.size() == label_mani[:,0].size()\n\n\t\tlabel_nat_sum  = label_nat.sum().item()\n\t\tlabel_mani_sum = label_mani.sum().item()\n\n\t\tdist_nat  = torch.masked_select(dist, label_nat[:,0])\n\t\tdist_mani = torch.max(torch.tensor(0).to(device).float(),\n\t\t\t\t\t\t\t  torch.sub(self.margin_mani, \n\t\t\t\t\t\t\t\t\t\ttorch.masked_select(dist, label_mani[:,0]))\n\t\t\t\t\t\t\t  )\n\n\t\tloss_nat  = dist_nat.sum()/label_nat_sum if label_nat_sum != 0 else \\\n\t\t\t\t\ttorch.tensor(0).to(device).float()\n\t\tloss_mani = dist_mani.sum()/label_mani_sum if label_mani_sum != 0 else \\\n\t\t\t\t\ttorch.tensor(0).to(device).float()\n\t\tloss_total = loss_nat + loss_mani\n\n\t\treturn loss_total.to(device), loss_mani.to(device), loss_nat.to(device)\n\n\tdef inference(self, model_output):\n\t\t'''output the distance for the final binary mask.'''\n\t\tbs, feat_dim, w, h = model_output.size()\n\t\tmodel_output = model_output.permute(0,2,3,1)\n\t\tmodel_output = torch.reshape(model_output, (-1, feat_dim))\n\t\tdist = self.pdist(model_output, self.c)\n\t\tself.dist = dist\n\t\tpred_mask = torch.gt(self.dist, self.threshold).to(torch.float32)\n\t\tpred_mask = torch.reshape(pred_mask, (bs,w,h,1)).permute(0,3,1,2)\n\t\tself.dist = torch.reshape(self.dist, (bs,w,h,1)).permute(0,3,1,2)\n\t\tself.dis_curBatch = pred_mask.to(device).to(torch.float32)\n\t\treturn self.dis_curBatch.squeeze(dim=1), self.dist.squeeze(dim=1)\n\ndef center_radius_init(args, FENet, SegNet, train_data_loader, debug=True, center=None):\n\t'''the center is the mean-value of pixel features of the real 
pixels'''\n\tsample_num = 0\n\tcenter = torch.zeros(18).to(device)\n\tFENet.eval()\n\tSegNet.eval()\n\twith torch.no_grad():\n\t\tfor batch_id, train_data in enumerate(tqdm(train_data_loader, desc=\"compute center\")):  \n\t\t\timage, masks, cls, fcls, scls, tcls = train_data\n\t\t\tif batch_id % 10 != 0:\n\t\t\t\tcontinue\n\t\t\tmask_cls = fcls.eq(0)\n\t\t\timage_selected = image[mask_cls,:]\n\t\t\tif image_selected.size()[0] == 0:\n\t\t\t\tcontinue\n\t\t\telse:\n\t\t\t\tsample_num += image_selected.size()[0]\n\t\t\tmask1 = masks[0].to(device)\n\t\t\timage_selected = image_selected.to(device)\n\t\t\tcls = cls.to(device)\n\t\t\tmask1_fea = FENet(image_selected)\n\t\t\tmask1_fea, _, _, _, _, _ = SegNet(mask1_fea, image_selected)\n\t\t\tmask1_fea = torch.mean(mask1_fea,(0,2,3))\n\t\t\tcenter += mask1_fea\n\n\tcenter = center/sample_num\n\tpdist  = torch.nn.PairwiseDistance(2)\n\tradius = torch.tensor(0, dtype=torch.float32).to(device)\n\twith torch.no_grad():\n\t\tfor batch_id, train_data in enumerate(tqdm(train_data_loader, desc=\"compute radius\")):  \n\t\t\tif batch_id % 10 != 0:\n\t\t\t\tcontinue    \n\t\t\timage, masks, cls, fcls, scls, tcls = train_data\n\t\t\tmask1 = masks[0].to(device)\n\t\t\timage = image.to(device)\n\t\t\tfcls = fcls.to(device)\n\t\t\tmask_cls = fcls.eq(0)\n\t\t\timage_selected = image[mask_cls,:]\n\t\t\tif image_selected.size()[0] == 0:\n\t\t\t\tcontinue\n\t\t\tmask1_fea = FENet(image_selected)\n\t\t\tmask1_fea, _, _, _, _, _ = SegNet(mask1_fea, image_selected)\n\t\t\tbs, channel, h, w = mask1_fea.size()\n\t\t\tmask1_fea = mask1_fea.permute(0,2,3,1)      \n\t\t\tmask1_fea = torch.reshape(mask1_fea, (bs*w*h, -1))\n\t\t\tdist = pdist(mask1_fea, center)\n\t\t\tdist_max = torch.max(dist)\n\t\t\tif radius < dist_max:\n\t\t\t\tradius = dist_max\n\treturn center, radius\n\ndef load_center_radius(args, FENet, SegNet, train_data_loader, center_radius_dir='center'):\n\t'''loading the pre-computed center and radius.'''\n\tcenter_radius_path = 
os.path.join(center_radius_dir, 'radius_center.pth')\n\tif os.path.exists(center_radius_path):\n\t\tload_dict_center_radius = torch.load(center_radius_path)\n\t\tcenter = load_dict_center_radius['center']\n\t\tradius = load_dict_center_radius['radius']\n\t\tcenter, radius = center.to(device), radius.to(device)\n\telse:\n\t\tos.makedirs(center_radius_dir, exist_ok=True)\n\t\tcenter, radius = center_radius_init(args, FENet, SegNet, train_data_loader, debug=True)\n\t\ttorch.save({'center': center, 'radius': radius}, center_radius_path)\n\treturn center, radius\n\ndef load_center_radius_api(center_radius_dir='center'):\n\t'''loading the pre-computed center and radius.'''\n\tcenter_radius_path = os.path.join(center_radius_dir, 'radius_center.pth')\n\tload_dict_center_radius = torch.load(center_radius_path)\n\tcenter = load_dict_center_radius['center']\n\tradius = load_dict_center_radius['radius']\n\tcenter, radius = center.to(device), radius.to(device)\n\treturn center, radius"
  },
  {
    "path": "utils/load_data.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\nfrom os.path import isfile, join\nfrom PIL import Image\nfrom torchvision import transforms\nimport numpy as np\nimport abc\nimport cv2\nimport torch.utils.data as data\nimport torch.nn.functional as F\nimport random\nrandom.seed(1234567890)\nfrom random import randrange\nimport torch.nn as nn\nimport torch\nimport imageio\nimport time\nimport math\nimport torch\n\nclass BaseData(data.Dataset):\n    '''\n        The dataset used for the IFDL dataset.\n    '''\n    def __init__(self, args):\n        super(BaseData, self).__init__()\n        self.crop_size = args.crop_size\n        self.file_path = '/user/guoxia11/cvlshare/cvl-guoxia11/IMDL/REAL'\n        self.file_path_fake = '/user/guoxia11/cvlshare/cvl-guoxia11/IMDL/FAKE'     \n\n        # Real and Fake images.\n        self.image_names = []\n        self.image_class = self._img_list_retrieve()\n        for idx, _ in enumerate(self.image_class):\n            self.image_names += _\n\n    def __getitem__(self, index):\n        res = self.get_item(index)\n        return res\n\n    def __len__(self):\n        return len(self.image_names)\n\n    @abc.abstractmethod\n    def _img_list_retrieve():\n        pass\n\n    def _resize_func(self, input_img):\n        '''resize the input image into the crop size.'''\n        input_img = Image.fromarray(input_img)\n        input_img = input_img.resize(self.crop_size, resample=Image.BICUBIC)\n        input_img = np.asarray(input_img)\n        return input_img\n\n    def get_image(self, image_name, aug_index=None):\n        '''transform the image.'''\n        image = imageio.imread(image_name)\n        if image.shape[-1] == 4:\n            image = self.rgba2rgb(image)\n        image = 
self._resize_func(image)\n        image = image.astype(np.float32) / 255.\n        image = torch.from_numpy(image)\n        return image.permute(2, 0, 1)\n\n    def rgba2rgb(self, rgba, background=(255, 255, 255)):\n        '''\n            turn rgba to rgb.\n        '''\n        row, col, ch = rgba.shape\n        rgb = np.zeros((row, col, 3), dtype='float32')\n        r, g, b, a = rgba[:, :, 0], rgba[:, :, 1], rgba[:, :, 2], rgba[:, :, 3]\n        a = np.asarray(a, dtype='float32') / 255.0\n\n        R, G, B = background\n        rgb[:, :, 0] = r * a + (1.0 - a) * R\n        rgb[:, :, 1] = g * a + (1.0 - a) * G\n        rgb[:, :, 2] = b * a + (1.0 - a) * B\n\n        return np.asarray(rgb, dtype='uint8')\n\n    def generate_4masks(self, mask):\n        '''generate 4 masks at different scale.'''\n        crop_height, crop_width = self.crop_size\n        ma_height, ma_width = mask.shape[:2]\n        mask_pil = Image.fromarray(mask)\n\n        if ma_height != crop_height or ma_width != crop_width:\n            mask_pil = mask_pil.resize(self.crop_size, resample=Image.BICUBIC)\n\n        (width2, height2) = (mask_pil.width // 2, mask_pil.height // 2)\n        (width3, height3) = (mask_pil.width // 4, mask_pil.height // 4)\n        (width4, height4) = (mask_pil.width // 8, mask_pil.height // 8)\n\n        mask2 = mask_pil.resize((width2, height2))\n        mask3 = mask_pil.resize((width3, height3))\n        mask4 = mask_pil.resize((width4, height4))\n\n        mask = np.asarray(mask_pil)\n        mask = mask.astype(np.float32) / 255\n        mask[mask > 0.5] = 1\n        mask[mask <= 0.5] = 0\n\n        mask2 = np.asarray(mask2).astype(np.float32) / 255\n        mask2[mask2 > 0.5] = 1\n        mask2[mask2 <= 0.5] = 0\n\n        mask3 = np.asarray(mask3).astype(np.float32) / 255\n        mask3[mask3 > 0.5] = 1\n        mask3[mask3 <= 0.5] = 0\n\n        mask4 = np.asarray(mask4).astype(np.float32) / 255\n        mask4[mask4 > 0.5] = 1\n        mask4[mask4 <= 0.5] = 
0\n\n        mask = torch.from_numpy(mask)\n        mask2 = torch.from_numpy(mask2)\n        mask3 = torch.from_numpy(mask3)\n        mask4 = torch.from_numpy(mask4)\n\n        # print(mask.size(), mask2.size(), mask3.size(), mask4.size())\n\n        return mask, mask2, mask3, mask4\n\n    def get_mask(self, image_name, cls, aug_index=None):\n        '''given the cls, we return the mask.'''\n        # authentic\n        if cls in [0,1,2,3,4]:\n            mask = self.load_mask('', real=True, aug_index=aug_index)\n            return_res = [0,0,0,0]\n        \n        # splice\n        elif cls == 5:\n            if '.jpg' in image_name:\n                mask_name = image_name.replace('fake', 'mask').replace('.jpg', '.png')\n            else:\n                mask_name = image_name.replace('fake', 'mask').replace('.tif', '.png')\n            mask = self.load_mask(mask_name, aug_index=aug_index)\n            return_res = [1,1,1,cls - 4]\n        \n        # inpainting\n        elif cls == 6:\n            mask_name = image_name.replace('/fake/', '/mask/').replace('.jpg', '.png')\n            mask = self.load_mask(mask_name, aug_index=aug_index)\n            return_res = [1,1,1,cls - 4]\n\n        # copy-move\n        elif cls == 7:\n            mask_name = image_name.replace('.png', '_mask.png')\n            mask_name = mask_name.replace('CopyMove', 'CopyMove_mask')\n            mask = self.load_mask(mask_name, aug_index=aug_index)\n            return_res = [1,1,1,cls - 4]\n\n        # faceshifter\n        elif cls == 8:  \n            image_id  = image_name.split('/')[-1].split('.')[0]\n            mask_name = image_name.replace(image_id, f'mask/{image_id}_mask')\n            mask = self.load_mask(mask_name, aug_index=aug_index)\n            return_res = [1,2,2,cls - 4]\n\n        # STGAN\n        elif cls == 9: \n            image_id  = image_name.split('/')[-1].split('.')[0]\n            mask_name = image_name.replace('fake', 'mask').replace(image_id, 
f'{image_id}_label')\n            mask = self.load_mask(mask_name, aug_index=aug_index)\n            return_res = [1,2,2,cls - 4]\n\n        ## they are star2, hisd, stylegan2, stylegan3, ddpm, ddim, latent, guided\n        elif cls in [10,11,12,13,14,15,16,17]:  \n            mask = self.load_mask('', real=False, full_syn=True, aug_index=aug_index)\n            if cls in [10,11]:\n                return_res = [2,3,3,cls-4]\n            elif cls in [12,13]:\n                return_res = [2,3,4,cls-4]\n            elif cls in [14,15]:\n                return_res = [2,4,5,cls-4]\n            elif cls in [16,17]:\n                return_res = [2,4,6,cls-4]\n        else:\n            print(cls, index)\n            raise Exception('class is not defined!')\n\n        return mask, return_res\n\n    def load_mask(self, mask_name, real=False, full_syn=False, gray=True, aug_index=None):\n        '''binarize the mask, given the mask_name.'''\n        if real:\n            mask = np.zeros(self.crop_size)\n        else:\n            if not full_syn:\n                mask = imageio.imread(mask_name) if not gray else np.asarray(Image.open(mask_name).convert('RGB').convert('L'))\n                mask = mask.astype(np.float32)\n            else:\n                mask = np.ones(self.crop_size)\n\n        mask = self.generate_4masks(mask)\n        return mask\n\n    def get_cls(self, image_name):\n        '''return the forgery/authentic cls given the image_name.'''\n        if '/authentic/' in image_name:\n            return_cls = 0\n        elif '/REAL/LSUN/' in image_name:\n            return_cls = 0\n        elif '/afhq_v2/' in image_name:\n            return_cls = 1\n        elif '/CelebAHQ/' in image_name:\n            return_cls = 2\n        elif '/FFHQ/' in image_name:\n            return_cls = 3\n        elif '/Youtube' in image_name:\n            return_cls = 4\n        elif '/splice' in image_name:\n            return_cls = 5\n        elif '/Inpainting' in image_name:\n    
        return_cls = 6\n        elif '/CopyMove' in image_name:\n            return_cls = 7\n        elif '/FaShifter' in image_name:\n            return_cls = 8\n        elif '/STGAN' in image_name:\n            return_cls = 9\n        elif '/Star2' in image_name:\n            return_cls = 10\n        elif '/HiSD' in image_name:\n            return_cls = 11\n        elif '/STYL2' in image_name:\n            return_cls = 12\n        elif '/STYL3' in image_name:\n            return_cls = 13\n        elif '/DDPM_' in image_name:\n            return_cls = 14\n        elif '/DDIM_' in image_name:\n            return_cls = 15\n        elif '/D_latent' in image_name:\n            return_cls = 16\n        elif '/GLIDE/' in image_name:\n            return_cls = 17\n        else:\n            print(image_name)\n            raise ValueError \n        return return_cls\n\nclass TrainData(BaseData):\n    '''\n        The dataset used for the IFDL dataset.\n    '''\n    def __init__(self, args):\n        self.is_train = True\n        self.val_num  = 90000\n        super(TrainData, self).__init__(args)\n\n    def img_retrieve(self, file_text, file_folder, real=True):\n        '''\n            Parameters:\n                file_text: str, text file for images.\n                file_folder: str, images folder.\n            Returns:\n                the image list.\n        '''\n        result_list = []\n        val_num   = self.val_num * 3 if file_text in [\"Youtube\", \"FaShifter\"] else self.val_num\n        data_path = self.file_path if real else self.file_path_fake\n\n        data_text = join(data_path, file_text)\n        data_path = join(data_path, file_folder)\n\n        file_handler = open(data_text)\n        contents = file_handler.readlines()\n        if self.is_train:\n            contents_lst = contents[:val_num]\n        else:\n            contents_lst = contents[val_num:]\n\n        for content in contents_lst:\n            if '.npy' not in content and 'mask' not in 
content:\n                img_name = content.strip()\n                img_name = join(data_path, img_name)\n                result_list.append(img_name)\n        file_handler.close()\n\n        ## only truncate the val_num images. \n        if len(result_list) < val_num:\n            mul_factor  = (val_num//len(result_list)) + 2\n            result_list = result_list * mul_factor\n        result_list = result_list[-val_num:]\n        return result_list\n\n    def get_item(self, index):\n        '''\n            given the index, this function returns the image with the forgery mask\n            this function calls get_image, get_mask for the image and mask torch tensor.\n        '''\n        image_name = self.image_names[index]\n        cls = self.get_cls(image_name)\n\n        # image and mask\n        aug_index = randrange(0, 8)\n        image = self.get_image(image_name, aug_index)\n        mask, return_res = self.get_mask(image_name, cls, aug_index)\n\n        return image, mask, return_res[0], return_res[1], return_res[2], return_res[3]\n\n    def _img_list_retrieve(self):\n        '''Returns image list for different authentic and forgery image.'''\n        authentic_names = self.img_retrieve('authentic.txt', 'authentic')\n        splice_names     = self.img_retrieve('splice_randmask.txt', 'splice_randmask/fake',False)\n        inpainting_names = self.img_retrieve('Inpainting.txt', 'Inpainting/fake', False)\n        copymove_names   = self.img_retrieve('copy_move.txt', 'CopyMove', False)\n        STGAN_names   = self.img_retrieve('STGAN.txt', 'STGAN/fake', False)\n        FaShifter_names  = self.img_retrieve('FaShifter.txt', 'FaShifter', False)\n        return [authentic_names, splice_names, inpainting_names, copymove_names, STGAN_names, FaShifter_names]\n\nclass ValData(BaseData):\n    '''\n        The dataset used for the IFDL dataset.\n    '''\n    def __init__(self, args):\n        self.is_train  = False\n        self.val_num   = 900\n        super(ValData, 
self).__init__(args)\n\n    def img_retrieve(self, file_text, file_folder, real=True):\n        '''\n            Parameters:\n                file_text: str, text file for images.\n                file_folder: str, images folder.\n            Returns:\n                the image list.\n        '''\n        result_list = []\n        val_num   = self.val_num * 3 if file_text in [\"Youtube\", \"FaShifter\"] else self.val_num\n        data_path = self.file_path if real else self.file_path_fake\n\n        data_text = join(data_path, file_text)\n        data_path = join(data_path, file_folder)\n\n        file_handler = open(data_text)\n        contents = file_handler.readlines()\n        for content in contents[-val_num:]:\n            if '.npy' not in content and 'mask' not in content:\n                img_name = content.strip()\n                img_name = join(data_path, img_name)\n                result_list.append(img_name)\n        file_handler.close()\n\n        ## only truncate the val_num images. 
\n        if len(result_list) < val_num:\n            mul_factor  = (val_num//len(result_list)) + 2\n            result_list = result_list * mul_factor\n        result_list = result_list[-val_num:]\n        return result_list\n    \n    def get_item(self, index):\n        '''\n            given the index, this function returns the image with the forgery mask\n            this function calls get_image, get_mask for the image and mask torch tensor.\n        '''\n        image_name = self.image_names[index]\n        cls = self.get_cls(image_name)\n\n        # image\n        image = self.get_image(image_name)\n        mask, return_res = self.get_mask(image_name, cls)\n\n        return image, mask, return_res[0], return_res[1], return_res[2], return_res[3], image_name\n\n    def _img_list_retrieve(self):\n        '''Returns image list for different authentic and forgery image.'''\n        STGAN_names   = self.img_retrieve('STGAN.txt', 'STGAN/fake', False)\n        FaShifter_names  = self.img_retrieve('FaShifter.txt', 'FaShifter', False)\n        return [STGAN_names, FaShifter_names]\n"
  },
  {
    "path": "utils/load_edata.py",
    "content": "from PIL import Image\nfrom torchvision import transforms\nfrom os.path import join\nimport abc\nimport numpy as np\nimport torch\nimport torch.utils.data as data\nimport imageio\nimport os\n\nclass BaseData(data.Dataset):\n    '''\n        The dataset used for the IFDL dataset.\n    '''\n    def __init__(self, args):\n        super(BaseData, self).__init__()\n        self.crop_size = args.crop_size\n        ## demo dataset:\n        self.mani_data_dir = './data_dir'\n        ## the full dataset:\n        # self.mani_data_dir = './data'\n        self.image_names = []\n        self.image_class = []\n        self.mask_names  = []\n\n    def __getitem__(self, index):\n        res = self.get_item(index)\n        return res\n\n    def __len__(self):\n        return len(self.image_names)\n\n    def generate_mask(self, mask):\n        '''\n            generate the corresponding binary mask.\n        '''\n        mask = mask.astype(np.float32) / 255\n        mask[mask > 0.5] = 1\n        mask[mask <= 0.5] = 0\n        mask = np.expand_dims(mask, axis=0)\n        mask = torch.from_numpy(mask)\n        return mask\n\n    def rgba2rgb(self, rgba, background=(255, 255, 255)):\n        '''\n            turn rgba to rgb.\n        '''\n        row, col, ch = rgba.shape\n        rgb = np.zeros((row, col, 3), dtype='float32')\n        r, g, b, a = rgba[:, :, 0], rgba[:, :, 1], rgba[:, :, 2], rgba[:, :, 3]\n        a = np.asarray(a, dtype='float32') / 255.0\n\n        R, G, B = background\n        rgb[:, :, 0] = r * a + (1.0 - a) * R\n        rgb[:, :, 1] = g * a + (1.0 - a) * G\n        rgb[:, :, 2] = b * a + (1.0 - a) * B\n        return np.asarray(rgb, dtype='uint8') # the output value is uint8 that belongs to [0,255]\n\n    def get_image(self, image_name):\n        '''\n            return the image with the tensor.\n        '''\n        image = imageio.imread(image_name)\n        if len(image.shape) == 2:\n            image = imageio.imread(image_name, 
as_gray=False, pilmode=\"RGB\")\n        if image.shape[-1] == 4:\n            image = self.rgba2rgb(image)\n        image = torch.from_numpy(image.astype(np.float32) / 255)\n        return image.permute(2, 0, 1)\n\n    def get_mask(self, mask_name):\n        '''\n            return the binary mask.\n        '''   \n        mask = Image.open(mask_name).convert('L')\n        mask = mask.resize(self.crop_size, resample=Image.BICUBIC)\n        mask = np.asarray(mask)\n        mask = self.generate_mask(mask) \n        return mask\n\n    @abc.abstractmethod\n    def get_item(self, index):\n        '''\n            blur\n            image = Image.fromarray(image)\n            image = image.filter(ImageFilter.GaussianBlur(radius=7))\n            image = np.asarray(image)\n\n            resize\n            image = Image.fromarray(image)\n            image = image.resize((int(image.width*0.25), int(image.height*0.25)), resample=Image.BILINEAR)\n            image = np.asarray(image)\n\n            noise\n            import skimage\n            image = skimage.util.random_noise(image/255., mode='gaussian', mean=0, var=15/255) * 255\n\n            jpeg compression\n            im = Image.open(image_name)\n            temp_name = './temp/' + image_name.split('/')[-1][:-3] + 'jpg'\n            im.save(temp_name, 'JPEG', quality=50)\n            image = Image.open(temp_name)\n            image = np.asarray(image)\n        '''\n        pass\n\nclass ValColumbia(BaseData):\n    def __init__(self, args):\n        super(ValColumbia, self).__init__(args)\n        ddir = os.path.join(self.mani_data_dir, 'columbia')\n        with open(join(ddir, 'vallist.txt')) as f:\n            contents = f.readlines()\n            for content in contents:\n                _ = os.path.join(ddir, '4cam_splc', content.strip())\n                self.image_names.append(_)\n        self.image_class = [1] * len(self.image_names)\n\n    def get_item(self, index):\n        image_name = 
self.image_names[index]\n        cls = self.image_class[index]\n\n        # image\n        image = self.get_image(image_name)\n\n        # mask\n        if '4cam_splc' in image_name:\n            mask_name = image_name.replace('4cam_splc', 'mask').replace('.tif', '.jpg')\n            mask = self.get_mask(mask_name)\n        else:\n            mask = np.zeros((1, 256, 256), dtype='float32')\n\n        return image, mask, cls, image_name\n\nclass ValCoverage(BaseData):\n    def __init__(self, args):\n        super(ValCoverage, self).__init__(args)\n        ddir = os.path.join(self.mani_data_dir, 'Coverage')\n        with open(join(ddir, 'fake.txt')) as f:\n            contents = f.readlines()\n            for content in contents:\n                _ = os.path.join(ddir, 'image', content.strip())\n                self.image_names.append(_)\n        self.image_class = [2] * len(self.image_names)\n\n    def get_item(self, index):\n        image_name = self.image_names[index]\n        cls = self.image_class[index]\n\n        ## read image.\n        image = self.get_image(image_name)\n\n        # mask\n        mask_name = image_name.replace('image', 'mask').replace('t.tif', 'forged.tif')\n        mask = self.get_mask(mask_name)\n\n        return image, mask, cls, image_name\n\nclass ValCasia(BaseData):\n    def __init__(self, args):\n        super(ValCasia, self).__init__(args)\n        ddir = os.path.join(self.mani_data_dir, 'CASIA/CASIA1')\n        with open(join(ddir, 'fake.txt')) as f:\n            contents = f.readlines()\n            for content in contents:\n                tag = content.split('/')[-1].split('_')[1]\n                if tag == 'D':\n                    self.image_class.append(1)\n                elif tag == 'S':\n                    self.image_class.append(2)\n                else:\n                    raise Exception('unknown class: {}'.format(content))\n                self.image_names.append(os.path.join(ddir, 'fake', content.strip()))\n\n        
ddir = os.path.join(self.mani_data_dir, 'CASIA/CASIA2')\n        with open(join(ddir, 'fake.txt')) as f:\n            contents = f.readlines()\n            for content in contents:\n                tag = content.split('/')[-1].split('_')[1]\n                if tag == 'D':\n                    self.image_class.append(1)\n                elif tag == 'S':\n                    self.image_class.append(2)\n                else:\n                    raise Exception('unknown class: {}'.format(content))\n                self.image_names.append(os.path.join(ddir, 'fake', content.strip()))\n\n    def get_item(self, index):\n        image_name = self.image_names[index]\n        cls = self.image_class[index]\n\n        # image\n        image = self.get_image(image_name)\n\n        # mask\n        if '.jpg' in image_name:\n            mask_name = image_name.replace('fake', 'mask').replace('.jpg', '_gt.png')\n        else:\n            mask_name = image_name.replace('fake', 'mask').replace('.tif', '_gt.png')\n        mask = self.get_mask(mask_name)\n\n        return image, mask, cls, image_name\n\nclass ValNIST16(BaseData):\n    def __init__(self, args):\n        super(ValNIST16, self).__init__(args)\n        ddir = os.path.join(self.mani_data_dir, 'NIST16')\n        file_name = 'alllist.txt'\n        with open(join(ddir, file_name)) as f:\n            contents = f.readlines()\n            for content in contents:\n                image_name, mask_name = content.split(' ')\n                self.image_names.append(join(ddir, image_name))\n                self.mask_names.append(join(ddir, mask_name.strip()))\n\n    def get_item(self, index):\n        image_name = self.image_names[index]\n        mask_name = self.mask_names[index]\n\n        if 'splice' in mask_name:\n            cls = 1\n        elif 'manipulation' in mask_name:\n            cls = 2\n        elif 'remove' in mask_name:\n            cls = 3\n        else:\n            cls = 0\n\n        # image\n        image = 
self.get_image(image_name)\n        if image.size()[2]*image.size()[1] >= 1000*1000:\n            image = imageio.imread(image_name)\n            if image.shape[-1] == 4:\n                image = self.rgba2rgb(image)\n            image = Image.fromarray(image)\n            image = image.resize((1000, 1000), resample=Image.BICUBIC)\n            image = np.asarray(image)\n            image = torch.from_numpy(image.astype(np.float32) / 255)\n            image = image.permute(2, 0, 1)\n\n        # mask\n        mask = self.get_mask(mask_name)\n        mask = torch.abs(mask - 1)\n\n        return image, mask, cls, image_name\n\nclass ValIMD2020(BaseData):\n    def __init__(self, args):\n        super(ValIMD2020, self).__init__(args)\n        ddir = os.path.join(self.mani_data_dir, 'IMD2020')\n        file_name = 'fake.txt'\n        with open(join(ddir, file_name)) as f:\n            contents = f.readlines()\n            for content in contents:\n                image_name = content.strip()\n                if '.jpg' in image_name:\n                    mask_name = image_name.replace('.jpg', '_mask.png')\n                else:\n                    mask_name = image_name.replace('.png', '_mask.png')\n                self.image_names.append(join(ddir, 'fake_img', image_name))\n                self.mask_names.append(join(ddir, 'mask', mask_name))\n        self.image_class = [2] * len(self.image_names)\n\n    def get_item(self, index):\n        image_name = self.image_names[index]\n        mask_name = self.mask_names[index]\n        cls = self.image_class[index]\n        try:\n            image = self.get_image(image_name)\n        except:\n            print(f\"Fail at {image_name}.\")\n        mask = self.get_mask(mask_name)\n\n        return image, mask, cls, image_name"
  },
  {
    "path": "utils/utils.py",
    "content": "# ------------------------------------------------------------------------------\n# Author: Xiao Guo (guoxia11@msu.edu)\n# CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization\n# ------------------------------------------------------------------------------\nimport os\nimport time\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\n\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nfrom kmeans_pytorch import kmeans\nfrom torchvision import transforms\nfrom torch.utils.data import DataLoader\nfrom sklearn import metrics\nfrom torchvision.utils import make_grid\nfrom einops import rearrange\nfrom PIL import Image\n\nSoftmax_m = nn.Softmax(dim=1)\ndevice = torch.device('cuda:0')\n\ndef device_ids_return(cuda_list):\n    '''return the device id'''\n    if  len(cuda_list) == 1:\n        device_ids = [0]\n    elif len(cuda_list) == 2:\n        device_ids = [0,1]\n    elif len(cuda_list) == 3:\n        device_ids = [0,1,2]\n    elif len(cuda_list) == 4:\n        device_ids = [0,1,2,3]\n    elif len(cuda_list) == 7:\n        device_ids = [0,1,2,3,4,5,6]\n    return device_ids\n\ndef findLastCheckpoint(save_dir):\n    if os.path.exists(save_dir):\n        file_list = os.listdir(save_dir)\n        result = 0\n        for file in file_list:\n            try:\n                num = int(file.split('.')[0].split('_')[-1])\n                result = max(result, num)\n            except:\n                continue\n        return result\n    else:\n        os.mkdir(save_dir)\n        return 0\n\ndef get_confusion_matrix(y_true, y_pred):\n    return metrics.confusion_matrix(y_true, y_pred)\n\ndef compute_cls_acc_f1(label_lst, pred_lst, iter_num, tb_writer, descr='Val/level3_1'):\n    F1  = metrics.f1_score(label_lst, pred_lst, average='macro')\n    acc = metrics.accuracy_score(label_lst, pred_lst)\n    tb_writer.add_scalar(f'{descr}_F1', F1, iter_num)\n    
tb_writer.add_scalar(f'{descr}_acc', acc, iter_num)\n    print(f\"In {descr}, the image-level Acc: {acc:.3f}, F1: {F1:.3f}.\")\n    print(\"******************************************************\")\n    return F1, acc\n\ndef tb_writer_display(writer, iter_num, lr_scheduler, epoch, \n                      seg_accu, loc_map_loss, manipul_loss, natural_loss, binary_loss,\n                      loss_1, loss_2, loss_3, loss_4):\n    writer.add_scalar('Train/seg_accu', seg_accu, iter_num)\n    writer.add_scalar('Train/map_loss', loc_map_loss, iter_num)\n    writer.add_scalar('Train/binary_map_loss', binary_loss, iter_num)\n    writer.add_scalar('Train/manip_loss', manipul_loss, iter_num)\n    writer.add_scalar('Train/natur_loss', natural_loss, iter_num)\n    writer.add_scalar('Train/loss_1', loss_1, iter_num)\n    writer.add_scalar('Train/loss_2', loss_2, iter_num)\n    writer.add_scalar('Train/loss_3', loss_3, iter_num)\n    writer.add_scalar('Train/loss_4', loss_3, iter_num)\n    for count, gp in enumerate(lr_scheduler.optimizer.param_groups,1):\n        writer.add_scalar('progress/lr_%d'%count, gp['lr'], iter_num)\n    writer.add_scalar('progress/epoch', epoch, iter_num)\n    writer.add_scalar('progress/curr_patience',lr_scheduler.num_bad_epochs,iter_num)\n    writer.add_scalar('progress/patience',lr_scheduler.patience,iter_num)\n\ndef one_hot_label(vector, Softmax_m=Softmax_m):\n    x = Softmax_m(vector)\n    x = torch.argmax(x, dim=1)\n    return x\n\ndef one_hot_label_new(vector, Softmax_m=Softmax_m):\n    '''\n        compute the probability for being as the synthesized image (TODO: double check).\n    '''\n    x = Softmax_m(vector)\n    indices = torch.argmax(x, dim=1)\n    prob = 1 - x[:,0]\n    indices = list(indices.cpu().numpy())\n    prob = list(prob.cpu().numpy())\n    return indices, prob\n\ndef level_1_convert(input_lst):\n    res_lst = []\n    for _ in input_lst:\n        if _ == 0:\n            res_lst.append(0)\n        else:\n            
res_lst.append(1)\n    return res_lst\n\ndef confusion_matrix_display(label_lst, res_lst, display_lst, display_name):\n    confusion_matrix = metrics.confusion_matrix(label_lst, res_lst)\n    cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, \n                                                display_labels = display_lst)\n    cm_display.plot()\n    plt.savefig(f'{display_name}.png')\n    confusion_matrix = metrics.confusion_matrix(label_lst, res_lst, normalize='true')\n    cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, \n                                                display_labels = display_lst)\n    cm_display.plot()\n    plt.savefig(f'{display_name}_normalized.png')\n\ndef make_folder(folder_name):\n    if not os.path.exists(folder_name):\n        os.makedirs(folder_name, exist_ok=True)\n        print(f\"Making folder {folder_name}.\")\n    else:\n        print(f\"Folder {folder_name} exists.\")\n\ndef class_weight(mask, mask_idx):\n    '''balance the weight on real and forgery pixel.'''\n    mask_balance = torch.ones_like(mask).to(torch.float)\n    if (mask == 1).sum():\n        mask_balance[mask == 1] = 0.5 / ((mask == 1).sum().to(torch.float) / mask.numel())\n        mask_balance[mask == 0] = 0.5 / ((mask == 0).sum().to(torch.float) / mask.numel())\n    else:\n        pass\n        # print(f'Mask{mask_idx} balance is not working!')\n    return mask.to(device), mask_balance.to(device)\n\ndef setup_optimizer(args, SegNet, FENet):\n    '''setup the optimizier, which applies different learning rate on different layers.'''\n    '''different hyper-parameters are changed towards HiFi-IFDL dataset.'''\n    params_dict_list = []\n    params_dict_list.append({'params' : SegNet.module.parameters(), 'lr' : args.learning_rate})\n    freq_list = []\n    para_list = []\n    for name, param in FENet.named_parameters():\n        if 'fre' in name:\n            freq_list += [param]\n        else:\n            
para_list += [param]\n    params_dict_list.append({'params' : freq_list, 'lr' : args.learning_rate*args.lr_backbone})\n    params_dict_list.append({'params' : para_list, 'lr' : args.learning_rate})\n\n    optimizer    = torch.optim.Adam(params_dict_list, lr=args.learning_rate*0.75, weight_decay=1e-06)\n    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=args.step_factor, min_lr=1e-08,\n                                     patience=args.patience, verbose=True)\n\n    return optimizer, lr_scheduler\n\ndef restore_weight_helper(model, model_dir, initial_epoch):\n    '''load model given the model_dir that has the model weights.'''\n    try:\n        weight_path = '{}/{}.pth'.format(model_dir, initial_epoch)\n        state_dict = torch.load(weight_path, map_location='cuda:0')['model']\n        model.load_state_dict(state_dict)\n        print('{} weight-loading succeeds: {}'.format(model_dir, weight_path))\n    except:\n        print('{} weight-loading fails'.format(model_dir))\n\n    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    print(\"{}_params: {}\".format(model_dir, pytorch_total_params))\n    return model\n\ndef restore_optimizer(optimizer, model_dir):\n    '''restore the optimizer.'''\n    try:\n        weight_path = '{}/{}.pth'.format(model_dir, initial_epoch)\n        state_dict = torch.load(weight_path, map_location='cuda:0')\n        print('Optimizer weight-loading succeeds.')\n        optimizer.load_state_dict(state_dict['optimizer'])\n    except:\n        # print('{} Optimizer weight-loading fails.')\n        pass\n    return optimizer\n\ndef composite_obj(args, loss, loss_1, loss_2, loss_3, loss_4, loss_binary):\n    ''' 'base', 'fg', 'local', 'full' '''\n    if args.ablation == 'full':     # fine-grained + localization\n        loss_total = 100*loss + loss_1 + loss_2 + loss_3 + 100*loss_4 + loss_binary\n    elif args.ablation == 'base':   # one-shot\n        loss_total = loss_4\n    elif 
args.ablation == 'fg':     # only fine-grained\n        loss_total = loss_1 + loss_2 + loss_3 + loss_4\n    elif args.ablation == 'local':  # only loclization\n        loss_total = loss + 10e-6*(loss_1 + loss_2 + loss_3 + loss_4)\n    else:\n        assert False\n    return loss_total\n\ndef composite_obj_step(args, loss_4_sum, map_loss_sum):\n    ''' return loss for the scheduler '''\n    if args.ablation == 'full':\n        schedule_step_loss = loss_4_sum + map_loss_sum\n    elif args.ablation == 'base':\n        schedule_step_loss = loss_4_sum\n    elif args.ablation == 'fg':\n        schedule_step_loss = loss_4_sum\n    elif args.ablation == 'local':\n        schedule_step_loss = map_loss_sum\n    else:\n        assert False\n    return schedule_step_loss\n\ndef viz_log(args, mask, pred_mask, image, iter_num, step, mode='train'):\n    '''viz training, val and inference.'''\n    mask = torch.unsqueeze(mask, dim=1)\n    pred_mask = torch.unsqueeze(pred_mask, dim=1)\n    mask_viz = torch.cat([mask]*3, axis=1)\n    pred_mask = torch.cat([pred_mask]*3, axis=1)\n    image = torch.nn.functional.interpolate(image,  # for viz.\n                                          size=(256, 256), \n                                          mode='bilinear')\n    fig_viz = torch.cat([mask_viz, image, pred_mask], axis=0)\n    grid = make_grid(fig_viz, nrow=mask_viz.shape[0])   # nrow in fact is the column number.\n    grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy()\n    img_h = Image.fromarray(grid.astype(np.uint8))\n    # os.makedirs(f\"./viz_{mode}_{args.learning_rate}/\", exist_ok=True)\n    os.makedirs(f\"./viz_{mode}/\", exist_ok=True)\n    if mode == 'train':\n        # img_h.save(f\"./viz_{mode}_{args.learning_rate}/iter_{iter_num}.jpg\")\n        img_h.save(f\"./viz_{mode}/iter_{iter_num}.jpg\")\n    else:\n        # img_h.save(f\"./viz_{mode}_{args.learning_rate}/iter_{iter_num}_step_{step}.jpg\")\n        img_h.save(f\"./viz_{mode}/iter_{iter_num}_step_{step}.jpg\")\n\ndef process_mask(mask, pred_mask):\n    '''process the mask'''\n    pred_mask = torch.unsqueeze(pred_mask, dim=1)\n    mask = torch.unsqueeze(mask, dim=1)\n    pred_mask = torch.cat([pred_mask]*3, axis=1)\n    mask = torch.cat([mask]*3, axis=1)\n\n    pred_mask = nn.functional.interpolate(pred_mask, \n                                        size=(256, 256), mode='bilinear')\n    mask = nn.functional.interpolate(mask, \n                                    size=(256, 256), mode='bilinear')\n\n    return pred_mask, mask\n\ndef viz_logs_scale(args, iter_num, mask_128, mask_64, mask_32, mask2, mask3, mask4, mode='train'):\n    '''visualize the mask and predicted mask.'''\n    pred_mask_128, mask128 = process_mask(mask_128, mask2)\n    pred_mask_64, mask64 = process_mask(mask_64, mask3)\n    pred_mask_32, mask32 = process_mask(mask_32, mask4)\n\n    fig_viz = torch.cat([pred_mask_32, mask32, pred_mask_64, mask64, \n                        pred_mask_128, mask128], axis=0)\n    grid = make_grid(fig_viz, nrow=pred_mask_32.shape[0])\n    grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy()\n    img_h = Image.fromarray(grid.astype(np.uint8))\n    os.makedirs(f\"./viz_{mode}_{args.learning_rate}/\", exist_ok=True)\n    img_h.save(f\"./viz_{mode}_{args.learning_rate}/iter_{iter_num}_pred.jpg\")\n\ndef train_log_dump(args, seg_correct, seg_total, map_loss_sum, mani_lss_sum, natu_lss_sum,\n                    binary_map_loss_sum, loss_1_sum, loss_2_sum, loss_3_sum,\n                    loss_4_sum, epoch, step, writer, iter_num, lr_scheduler):\n    '''compute and output the different training loss & seg accuarcy.'''\n    seg_accu = seg_correct / seg_total * 100\n    loc_map_loss = map_loss_sum / args.dis_step\n    manipul_loss = mani_lss_sum / args.dis_step\n    natural_loss = natu_lss_sum / args.dis_step\n    binary_loss  = binary_map_loss_sum / args.dis_step\n    loss_1 = loss_1_sum / args.dis_step\n    loss_2 = loss_2_sum / args.dis_step\n    loss_3 = loss_3_sum / args.dis_step\n    loss_4 = loss_4_sum / args.dis_step\n    print(f'[Epoch: {epoch+1}, Step: {step + 1}] batch_loc_acc: {seg_accu:.2f}')\n    print(f'cls1_loss: {loss_1:.3f}, cls2_loss: {loss_2:.3f}, cls3_loss: {loss_3:.3f}, '+\n          f'cls4_loss: {loss_4:.3f}, map_loss:   {loc_map_loss:.3f}, '+\n          f'manip_loss: {manipul_loss:.3f}, natur_loss: {natural_loss:.3f}, '+\n          f'binary_map_loss: {binary_loss:.3f}') \n    '''write the tensorboard.'''\n    tb_writer_display(writer, iter_num, lr_scheduler, epoch, seg_accu, \n                      loc_map_loss, manipul_loss, natural_loss, binary_loss,\n                      loss_1, loss_2, loss_3, loss_4)"
  },
  {
    "path": "weights/put_weights_here",
    "content": ""
  }
]