Repository: CHELSEA234/HiFi_IFDL Branch: main Commit: 0ca70d651087 Files: 106 Total size: 75.3 MB Directory structure: gitextract_j0fr7ihb/ ├── HiFi_Net.py ├── HiFi_Net_loc.py ├── HiFi_Net_loc.sh ├── IMD_dataloader.py ├── LICENSE ├── README.md ├── applications/ │ ├── CNNImage_detection/ │ │ └── README.md │ ├── DiffVideo_detection/ │ │ └── README.md │ └── deepfake_detection/ │ ├── FF++/ │ │ └── put_weight_here │ ├── README.md │ ├── dataset_test.py │ ├── dataset_test.sh │ ├── environment.yml │ ├── exp_FF_c40_bs_32_lr_0.0001_ws_10.txt │ ├── sequence/ │ │ ├── models/ │ │ │ ├── GaussianSmoothing.py │ │ │ ├── HiFiNet_deepfake.py │ │ │ ├── LaPlacianMs.py │ │ │ ├── hrnet/ │ │ │ │ ├── hrnet_w18_small_model_v2.pth │ │ │ │ ├── seg_hrnet.py │ │ │ │ └── seg_hrnet_config.py │ │ │ └── run_model.sh │ │ ├── rnn_stratified_dataloader.py │ │ ├── runjobs_utils.py │ │ └── torch_utils.py │ ├── test.py │ ├── test.sh │ ├── train.py │ └── train.sh ├── center/ │ └── radius_center.pth ├── center_loc/ │ └── radius_center.pth ├── data_dir/ │ ├── CASIA/ │ │ ├── CASIA1/ │ │ │ └── fake.txt │ │ └── CASIA2/ │ │ ├── fake/ │ │ │ ├── Tp_D_CND_M_N_ani00018_sec00096_00138.tif │ │ │ ├── Tp_D_CND_M_N_art00076_art00077_10289.tif │ │ │ └── Tp_D_CND_M_N_art00077_art00076_10290.tif │ │ └── fake.txt │ ├── Coverage/ │ │ ├── fake.txt │ │ ├── image/ │ │ │ ├── 10t.tif │ │ │ ├── 11t.tif │ │ │ ├── 12t.tif │ │ │ ├── 13t.tif │ │ │ ├── 14t.tif │ │ │ ├── 15t.tif │ │ │ ├── 16t.tif │ │ │ ├── 17t.tif │ │ │ ├── 18t.tif │ │ │ ├── 19t.tif │ │ │ └── 1t.tif │ │ └── mask/ │ │ ├── 10copy.tif │ │ ├── 10forged.tif │ │ ├── 10paste.tif │ │ ├── 11copy.tif │ │ ├── 11forged.tif │ │ ├── 11paste.tif │ │ ├── 12copy.tif │ │ ├── 12forged.tif │ │ ├── 12paste.tif │ │ ├── 13copy.tif │ │ ├── 13forged.tif │ │ ├── 13paste.tif │ │ ├── 14copy.tif │ │ ├── 14forged.tif │ │ ├── 14paste.tif │ │ ├── 15copy.tif │ │ ├── 15forged.tif │ │ ├── 15paste.tif │ │ ├── 16copy.tif │ │ ├── 16forged.tif │ │ ├── 16paste.tif │ │ ├── 17copy.tif │ │ ├── 17forged.tif │ │ ├── 17paste.tif │ │ ├── 18copy.tif │ │ ├── 18forged.tif │ │ ├── 18paste.tif │ │ ├── 19copy.tif │ │ ├── 19forged.tif │ │ ├── 19paste.tif │ │ ├── 1copy.tif │ │ ├── 1forged.tif │ │ └── 1paste.tif │ ├── IMD2020/ │ │ └── fake.txt │ ├── NIST16/ │ │ └── alllist.txt │ └── columbia/ │ ├── 4cam_splc/ │ │ ├── canong3_canonxt_sub_01.tif │ │ ├── canong3_canonxt_sub_02.tif │ │ ├── canong3_canonxt_sub_03.tif │ │ ├── canong3_canonxt_sub_04.tif │ │ ├── canong3_canonxt_sub_05.tif │ │ ├── canong3_canonxt_sub_06.tif │ │ ├── canong3_canonxt_sub_07.tif │ │ ├── canong3_canonxt_sub_08.tif │ │ └── canong3_canonxt_sub_09.tif │ └── vallist.txt ├── environment.yml ├── models/ │ ├── GaussianSmoothing.py │ ├── LaPlacianMs.py │ ├── NLCDetection_api.py │ ├── NLCDetection_loc.py │ ├── NLCDetection_pconv.py │ ├── hrnet_w18_small_v2.pth │ ├── seg_hrnet.py │ └── seg_hrnet_config.py ├── utils/ │ ├── custom_loss.py │ ├── load_data.py │ ├── load_edata.py │ └── utils.py └── weights/ └── put_weights_here ================================================ FILE CONTENTS ================================================ ================================================ FILE: HiFi_Net.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ from utils.utils import * from utils.custom_loss import IsolatingLossFunction, load_center_radius_api from models.seg_hrnet import get_seg_model from models.seg_hrnet_config import get_cfg_defaults from models.NLCDetection_api import NLCDetection from PIL import Image import torch import torch.nn as nn import numpy as np import argparse import imageio as imageio class HiFi_Net(): ''' FENET is the multi-branch feature extractor. SegNet contains the classification and localization modules. LOSS_MAP is the classification loss function class. ''' def __init__(self): device = torch.device('cuda:0') device_ids = [0] FENet_cfg = get_cfg_defaults() FENet = get_seg_model(FENet_cfg).to(device) # load the pre-trained model inside. SegNet = NLCDetection().to(device) FENet = nn.DataParallel(FENet) SegNet = nn.DataParallel(SegNet) self.FENet = restore_weight_helper(FENet, "weights/HRNet", 750001) self.SegNet = restore_weight_helper(SegNet, "weights/NLCDetection", 750001) self.FENet.eval() self.SegNet.eval() center, radius = load_center_radius_api() self.LOSS_MAP = IsolatingLossFunction(center,radius).to(device) def _transform_image(self, image_name): '''transform the image.''' image = imageio.imread(image_name) image = Image.fromarray(image) image = image.resize((256,256), resample=Image.BICUBIC) image = np.asarray(image) image = image.astype(np.float32) / 255. image = torch.from_numpy(image) image = image.permute(2, 0, 1) image = torch.unsqueeze(image, 0) return image def _normalized_threshold(self, res, prob, threshold=0.5, verbose=False): '''to interpret detection result via omitting the detection decision.''' if res > threshold: decision = "Forged" prob = (prob - threshold) / threshold else: decision = 'Real' prob = (threshold - prob) / threshold print(f'Image being {decision} with the confidence {prob*100:.1f}.') def detect(self, image_name, verbose=False): """ Para: image_name is string type variable for the image name. Return: res: binary result for real and forged. prob: the prob being the forged image. """ with torch.no_grad(): img_input = self._transform_image(image_name) output = self.FENet(img_input) mask1_fea, mask1_binary, out0, out1, out2, out3 = self.SegNet(output, img_input) res, prob = one_hot_label_new(out3) res = level_1_convert(res)[0] if not verbose: return res, prob[0] else: self._normalized_threshold(res, prob[0]) def localize(self, image_name): """ Para: image_name is string type variable for the image name. Return: binary_mask: forgery mask. """ with torch.no_grad(): img_input = self._transform_image(image_name) output = self.FENet(img_input) mask1_fea, mask1_binary, out0, out1, out2, out3 = self.SegNet(output, img_input) pred_mask, pred_mask_score = self.LOSS_MAP.inference(mask1_fea) # inference pred_mask_score = pred_mask_score.cpu().numpy() ## 2.3 is the threshold used to seperate the real and fake pixels. ## 2.3 is the dist between center and pixel feature in the hyper-sphere. ## for center and pixel feature please refer to "IsolatingLossFunction" in custom_loss.py pred_mask_score[pred_mask_score<2.3] = 0. pred_mask_score[pred_mask_score>=2.3] = 1. binary_mask = pred_mask_score[0] return binary_mask def inference(img_path): HiFi = HiFi_Net() # initialize ## detection res3, prob3 = HiFi.detect(img_path) # print(res3, prob3) 1 1.0 HiFi.detect(img_path, verbose=True) ## localization binary_mask = HiFi.localize(img_path) binary_mask = Image.fromarray((binary_mask*255.).astype(np.uint8)) binary_mask.save('pred_mask.png') if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--img_path', type=str, default='asset/sample_1.jpg') args = parser.parse_args() inference(args.img_path) ================================================ FILE: HiFi_Net_loc.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ from utils.utils import * from IMD_dataloader import * from utils.custom_loss import IsolatingLossFunction, load_center_radius from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from models.seg_hrnet import get_seg_model from models.seg_hrnet_config import get_cfg_defaults from models.NLCDetection_loc import NLCDetection from sklearn import metrics from sklearn.metrics import roc_auc_score from torchvision.utils import make_grid from einops import rearrange from PIL import Image from sklearn import metrics import os import csv import time import torch import torch.nn as nn import argparse import numpy as np device = torch.device('cuda:0') device_ids = [0] def config(args): '''Set up input configurations.''' args.crop_size = [args.crop_size, args.crop_size] # cuda_list = args.list_cuda global device device = torch.device('cuda:0') # global device_ids # device_ids = device_ids_return(cuda_list) args.save_dir = 'lr_' + str(args.learning_rate)+'_loc' FENet_dir, SegNet_dir = args.save_dir+'/HRNet', args.save_dir+'/NLCDetection' FENet_cfg = get_cfg_defaults() FENet = get_seg_model(FENet_cfg).to(device) # load the pre-trained model inside. SegNet = NLCDetection().to(device) FENet = nn.DataParallel(FENet, device_ids=device_ids) SegNet = nn.DataParallel(SegNet, device_ids=device_ids) writer = None return args, writer, FENet, SegNet, FENet_dir, SegNet_dir def restore_weight(args, FENet, SegNet, FENet_dir, SegNet_dir): '''load FENet, SegNet and optimizer.''' params = list(FENet.parameters()) + list(SegNet.parameters()) optimizer = torch.optim.Adam(params, lr=args.learning_rate) initial_epoch = findLastCheckpoint(save_dir=SegNet_dir) # load FENet and SegNet weight: FENet = restore_weight_helper(FENet, FENet_dir, initial_epoch) SegNet = restore_weight_helper(SegNet, SegNet_dir, initial_epoch) optimizer = restore_optimizer(optimizer, SegNet_dir) return optimizer, initial_epoch def Inference_loc( args, FENet, SegNet, LOSS_MAP, tb_writer, iter_num=None, save_tag=False, localization=True ): ''' the inference pipeline for the pre-trained model. the image-level detection will dump to the csv file. the pixel-level localization will be saved as in the npy file. ''' for val_tag in [0,1,2,3,4]: val_data_loader, data_label = eval_dataset_loader_init(args, val_tag) print(f"working on the dataset: {data_label}.") F1_lst, auc_lst = [], [] with torch.no_grad(): FENet.eval() SegNet.eval() for step, val_data in enumerate(tqdm(val_data_loader)): image, mask, cls, image_names = val_data image, mask = image.to(device), mask.to(device) mask = torch.squeeze(mask, axis=1) # model try: output = FENet(image) mask1_fea, mask_binary, out0, out1, out2, out3 = SegNet(output, image) except: print(f"does not work on the ", image_names) continue if args.loss_type == 'dm': loss_map, loss_manip, loss_nat = LOSS_MAP(mask1_fea, mask) pred_mask = LOSS_MAP.dis_curBatch.squeeze(dim=1) pred_mask_score = LOSS_MAP.dist.squeeze(dim=1) elif args.loss_type == 'ce': pred_mask_score = mask_binary pred_mask = torch.zeros_like(mask_binary) pred_mask[mask_binary > 0.5] = 1 pred_mask[mask_binary <= 0.5] = 0 viz_log(args, mask, pred_mask, image, iter_num, f"{step}_{val_tag}", mode='eval') mask = torch.unsqueeze(mask, axis=1) for img_idx, cur_img_name in enumerate(image_names): mask_ = torch.unsqueeze(mask[img_idx,0], 0) pred_mask_ = torch.unsqueeze(pred_mask[img_idx], 0) pred_mask_score_ = torch.unsqueeze(pred_mask_score[img_idx], 0) mask_ = mask_.cpu().clone().cpu().numpy().reshape(-1) pred_mask_ = pred_mask_.cpu().clone().cpu().numpy().reshape(-1) pred_mask_score_ = pred_mask_score_.cpu().clone().cpu().numpy().reshape(-1) F1_a = metrics.f1_score(mask_, pred_mask_, average='macro') auc_a = metrics.roc_auc_score(mask_, pred_mask_score_) pred_mask_[np.where(pred_mask_ == 0)] = 1 pred_mask_[np.where(pred_mask_ == 1)] = 0 F1_b = metrics.f1_score(mask_, pred_mask_, average='macro') if F1_a > F1_b: F1 = F1_a else: F1 = F1_b F1_lst.append(F1) AUC_score = auc_a if auc_a > 0.5 else 1-auc_a auc_lst.append(AUC_score) print("F1: ", np.mean(F1_lst)) print("AUC: ", np.mean(auc_lst)) def main(args): ## Set up the configuration. args, writer, FENet, SegNet, FENet_dir, SegNet_dir = config(args) ## load FENet and SegNet weight: if args.loss_type == 'ce': FENet = restore_weight_helper(FENet, "weights/HRNet", 225000) SegNet = restore_weight_helper(SegNet, "weights/NLCDetection", 225000) elif args.loss_type == 'dm': FENet = restore_weight_helper(FENet, "weights/HRNet", 315000) SegNet = restore_weight_helper(SegNet, "weights/NLCDetection", 315000) else: raise ValueError ## Set up the loss function. center, radius = load_center_radius(args, FENet, SegNet, train_data_loader=None, center_radius_dir='./center_loc') CE_loss = nn.CrossEntropyLoss().to(device) BCE_loss = nn.BCELoss(reduction='none').to(device) LOSS_MAP = IsolatingLossFunction(center,radius).to(device) Inference_loc( args, FENet, SegNet, LOSS_MAP, tb_writer=writer, iter_num=99999, save_tag=True, localization=True ) print("after saving the points...") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('-l','--list_cuda', nargs='+', help=' Set flag') parser.add_argument('-lr', '--learning_rate', type=float, default=5e-5) parser.add_argument('--num_epochs', type=int, default=3) parser.add_argument('--lr_gamma', type=float, default=2.0) parser.add_argument('--lr_backbone', type=float, default=0.9) parser.add_argument('--patience', type=int, default=30) parser.add_argument('--step_factor', type=float, default=0.95) parser.add_argument('--dis_step', type=int, default=50) parser.add_argument('--val_step', type=int, default=500) ## train hyper-parameters parser.add_argument('--crop_size', type=int, default=256) parser.add_argument('--val_num', type=int, default=200, help='val sample number.') parser.add_argument('--train_num', type=int, default=360000, help='train sample number.') parser.add_argument('--train_tag', type=int, default=0) parser.add_argument('--val_tag', type=int, default=0) parser.add_argument('--val_all', type=int, default=1) parser.add_argument('--ablation', type=str, default='local', choices=['base', 'fg', 'local', 'full'], help='exp for one-shot, fine_grain, plus localization, plus pconv') parser.add_argument('--val_loc_tag', action='store_true') parser.add_argument('--fine_tune', action='store_true') parser.add_argument('--debug_mode', action='store_true') parser.set_defaults(val_loc_tag=True) parser.set_defaults(fine_tune=True) parser.add_argument('--train_ratio', nargs='+', default="0.4 0.4 0.2", help='deprecated') parser.add_argument('--path', type=str, default="", help='deprecated') parser.add_argument('--train_bs', type=int, default=10, help='batch size in the training.') parser.add_argument('--val_bs', type=int, default=10, help='batch size in the validation.') parser.add_argument('--percent', type=float, default=1.0, help='label dataset.') parser.add_argument('--loss_type', type=str, default='ce', choices=['ce', 'dm'], help='ce or deep metric.') ## inference hyperparameters: parser.add_argument('--initial_epoch', type=int, default=70500) args = parser.parse_args() main(args) ================================================ FILE: HiFi_Net_loc.sh ================================================ source ~/.bashrc conda activate HiFi_Net CUDA_NUM=2 CUDA_VISIBLE_DEVICES=$CUDA_NUM python HiFi_Net_loc.py ================================================ FILE: IMD_dataloader.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo, Xiaohong Liu. # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ from torch.utils.data import DataLoader from utils.load_data import TrainData, ValData from utils.load_edata import * def train_dataset_loader_init(args): train_dataset = TrainData(args) train_data_loader = DataLoader( train_dataset, batch_size=args.train_bs, shuffle=True, # shuffle=False, num_workers=8 ) return train_data_loader def infer_dataset_loader_init(args, shuffle=True, bs=8): val_dataset = ValData(args) val_data_loader = DataLoader( val_dataset, batch_size=bs, shuffle=shuffle, # shuffle=True, num_workers=8 ) return val_data_loader def eval_dataset_loader_init(args, val_tag, batch_size=1): if val_tag == 0: data_label = 'columbia' val_data_loader = DataLoader(ValColumbia(args), batch_size=batch_size, shuffle=False, num_workers=0) elif val_tag == 1: data_label = 'coverage' val_data_loader = DataLoader(ValCoverage(args), batch_size=batch_size, shuffle=False, num_workers=0) elif val_tag == 2: data_label = 'casia' val_data_loader = DataLoader(ValCasia(args), batch_size=batch_size, shuffle=False, num_workers=0) elif val_tag == 3: data_label = 'NIST16' val_data_loader = DataLoader(ValNIST16(args), batch_size=batch_size, shuffle=False, num_workers=0) elif val_tag == 4: data_label = 'IMD2020' val_data_loader = DataLoader(ValIMD2020(args), batch_size=batch_size, shuffle=False, num_workers=0) return val_data_loader, data_label ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023 Xiao Guo Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # HiFi_IFDL This is the source code for our CVPR $2023$: "*Hierarchical Fine-Grained Image Forgery Detection and Localization*." [[Arxiv]](https://arxiv.org/pdf/2303.17111.pdf) Authors: [Xiao Guo](https://scholar.google.com/citations?user=Gkc-lAEAAAAJ&hl=en), [Xiaohong Liu](https://jhc.sjtu.edu.cn/~xiaohongliu/), [Zhiyuan Ren](https://scholar.google.com/citations?user=Z1ltuXEAAAAJ&hl=en), [Steven Grosz](https://scholar.google.com/citations?user=I1wOjTYUyYAC&hl=en), [Iacopo Masi](https://iacopomasi.github.io/), [Xiaoming Liu](http://cvlab.cse.msu.edu/)

drawing

### Updates. - [Sep 2024] 👏 The International Journal of Computer Vision (**IJCV**) has accepted the extended version of HiFi-Net, stay tuned~ - [Aug 2024] The HiFi-Net is integrated into the DeepFake-o-meter v2.0 platform, which is a user-friendly public detection tool designed by the **University at Buffalo**. [[DeepFake-o-meter v2.0]](https://zinc.cse.buffalo.edu/ubmdfl/deep-o-meter/home_login) [[ArXiv]](https://arxiv.org/pdf/2404.13146) - [Jul. 2024] 👏 **ECCV2024** "Deepfake Explainer" paper [[ArXiv]](https://arxiv.org/pdf/2402.00126) reports HiFi-Net's deep fake detection performance and the source code is released [[link]](https://github.com/CHELSEA234/HiFi_IFDL/edit/main/applications/deepfake_detection). - [Sep 2023] The first version dataset can be acquired via this link: [Dataset Link](https://drive.google.com/drive/folders/1fwBEmW30-e0ECpCNNG3nRU6I9OqJfMAn?usp=sharing) - [June 2023] The extended version of our work has been submitted to one of the ~~Machine Learning Journals~~ IJCV. - **This GitHub will keep updated, please stay tuned~** ### Short 5 Min Video [![Please Click the Figure](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/figures/architecture.png)](https://www.youtube.com/watch?v=FwS3X5xcj8A&list=LL&index=5) ### Usage on Manipulation Localization (_e.g._, Columbia, Coverage, CASIA, NIST16 and IMD2020) - To create your environment by ``` conda env create -f environment.yml ``` or mannually install `pytorch 1.11.0` and `torchvision 0.12.0` in `python 3.7.16`. - Go to [localization_weights_link](https://drive.google.com/drive/folders/1cxCoE2hjcDj4lLrJmGEbskzPRJfoDIMJ?usp=sharing) to download the weights from, and then put them in `weights`. - To apply the pre-trained model on images in the `./data_dir` and then obtain results in `./viz_eval`, please run ``` bash HiFi_Net_loc.sh ``` - More quantitative and qualitative results can be found at: [csv](https://drive.google.com/drive/folders/12iS0ILb6ndXtdWjonByrgnejzuAvwCqp?usp=sharing) and [qualitative results](https://drive.google.com/drive/folders/1iZp6ciOHSbGq4EsC_AYl7zVK24gBtrd1?usp=sharing). - If you would like to generate the above result. Download $5$ datasets via [link](https://drive.google.com/file/d/1RYXTg0Q82KEvkeOtaaR5AZ0FBx5219SY/view?usp=sharing) and unzip it by `tar -xvf data.tar.gz`. Then, uncomment this [line](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/utils/load_edata.py#L21) and run `HiFi_Net_loc.sh`. ### Usage on Detecting and Localization for the general forged content including GAN and diffusion-generated images: - This reproduces detection and localization results in the HiFi-IFDL dataset (Tab. 2 and Supplementary Fig.1) - Go to [HiFi_IFDL_weights_link](https://drive.google.com/drive/folders/1v07aJ2hKmSmboceVwOhPvjebFMJFHyhm?usp=sharing) to download the weights, and then put them in `weights`. - The quick usage on HiFi_Net: ```python from HiFi_Net import HiFi_Net from PIL import Image import numpy as np HiFi = HiFi_Net() # initialize img_path = 'asset/sample_1.jpg' ## detection res3, prob3 = HiFi.detect(img_path) # print(res3, prob3) 1 1.0 HiFi.detect(img_path, verbose=True) ## localization binary_mask = HiFi.localize(img_path) binary_mask = Image.fromarray((binary_mask*255.).astype(np.uint8)) binary_mask.save('pred_mask.png') ``` ### Quick Start of Source Code A quick view of the code structure: ```bash ./HiFi_IFDL ├── HiFi_Net_loc.py (localization files) ├── HiFi_Net_loc.sh (localization evaluation) ├── HiFi_Net.py (API for the user input image.) ├── IMD_dataloader.py (call dataloaders in the utils folder) ├── model (model module folder) │ ├── NLCDetection_pconv.py (partial convolution, localization, and classification modules) │ ├── seg_hrnet.py (feature extractor based on HRNet) │ ├── LaPlacianMs.py (laplacian filter on the feature map) │ ├── GaussianSmoothing.py (self-made smoothing functions) │ └── ... ├── utils (utils, dataloader, and localization loss class.) │ ├── custom_loss.py (localization loss class and the real pixel center initialization) │ ├── utils.py │ ├── load_data.py (loading training and val dataset.) │ └── load_edata.py (loading inference dataset.) ├── asset (folder contains sample images with their ground truth and predictions.) ├── weights (put the pre-trained weights in.) ├── center (The pre-computed `.pth` file for the HiFi-IFDL dataset.) └── center_loc (The pre-computed `.pth` file for the localization task (Tab.3 in the paper).) ``` ### Question and Answers. Q1. Why train and val datasets are in the same path? A1. For each forgery method, we save both train and val in the SAME folder, from which we use a text file to obtain the training and val images. The text file contains a list of image names, and the first `val_num` are used for training and the last "val_num" for validation. Specifically, refer to [code](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/utils/load_data.py#L271) for details. What is more, we build up the code on the top of the PSCC-Net, which adapts the same style of loading data, please compare [code1](https://github.com/proteus1991/PSCC-Net/blob/main/utils/load_tdata.py#L88) with [code2](https://github.com/proteus1991/PSCC-Net/blob/main/utils/load_tdata.py#L290). Q2. What is the dataset naming for STGAN and the face-shifter section? A2. Please check the STGAN.txt in this [link](https://drive.google.com/drive/folders/1OIUv7OGxfAyerMnmKvrNnN_5CmIDcNxo?usp=sharing), which contains all manipulated/modified images we have used for training and validation. This txt file will be loaded by this line of [code](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/utils/load_data.py#L163), which says about the corresponding masks. Lastly, I am not sure if I have release the authentic images, if I do not, you can simply find them in the public celebAHQ dataset. I will try to offer the rigid naming for the dataset in the near future. ### Reference If you would like to use our work, please cite: ```Bibtex @inproceedings{hifi_net_xiaoguo, author = { Xiao Guo and Xiaohong Liu and Zhiyuan Ren and Steven Grosz and Iacopo Masi and Xiaoming Liu }, title = { Hierarchical Fine-Grained Image Forgery Detection and Localization }, booktitle = { CVPR }, year = { 2023 }, } ``` ================================================ FILE: applications/CNNImage_detection/README.md ================================================ ================================================ FILE: applications/DiffVideo_detection/README.md ================================================ ================================================ FILE: applications/deepfake_detection/FF++/put_weight_here ================================================ ================================================ FILE: applications/deepfake_detection/README.md ================================================ # HiFi_Deepfake We apply the HiFi_Net for the deepfake detection as the following diagram:

drawing

### Reported Performance
| Dataset | AUC | Accuracy | EER | TPR@FPR=**$10$**% |TPR@FPR=**$1$**% | |:----:|:----:|:----:|:----:|:----:|:----:| |FF++(c40)|$92.10$|$89.16$|N/A|$74.44$|$40.85$ |CelebDF|$68.80$|$67.20$|$36.13$|N/A|N/A |WildDeepfake|$65.22$|$66.29$|$38.65$|N/A|N/A
More results please refer to the table $3$ of our ECCV2024 paper [[ArXiv]](https://arxiv.org/pdf/2402.00126) ### The Pre-trained Weights and User-friendly Preprocessed Dataset: 1. The pre-trained weights on FF++ can be download via [[link]](https://drive.google.com/drive/folders/1AElYlVxsahgGIua3m3Kj2VhSc3S7ADLJ?usp=sharing) 2. We offer a preprocessed FF++ dataset in the HDF5 file format [[link]](https://drive.google.com/drive/folders/1ovuurFCkBfmcMq7HKO5ph36U1QyL75UA?usp=sharing), supporting faster I/O. The dataset follows the naming ```FF++_{manipulation_type}_{compression rate}.h5``` and is structured as follows: ``` FF++_Deepfakes_c23.h5: FF++_Deepfakes_c40.h5 FF++_Face2Face_c23.h5 FF++_Face2Face_c40.h5 ... ``` ### Quick Start 1. Setup the environment using ```environment.yml```, then put the pre-trained weights in ```FF++``` folder. 2. Download the entire dataset or a small portion of datasets, for example ```FF++_original_c40.h5``` and ```FF++_Deepfakes_c40.h5```. 3. Run `bash test.sh` after setting up the data path [here](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/applications/deepfake_detection/test.py#L106). 4. If you choose to run the small portion dataset (e.g., ```FF++_original_c40.h5``` and ```FF++_Deepfakes_c40.h5```), please comment this [link](https://github.com/CHELSEA234/HiFi_IFDL/blob/main/applications/deepfake_detection/test.py#L34) ### Quick View of Code ```bash ./deepfake_detection ├── test.py (the inference code.) ├── test.sh (run the inference code.) ├── dataset_test.py (dataset tutorial) ├── dataset_test.sh (dataset tutorial) ├── train.py (the train code.) ├── train.sh (run the train code.) ├── exp_FF_c40_bs_32_lr_0.0001_ws_10.txt (The training log file.) ├── FF++ (Please download the pre-trained weights and put it here) ├── sequence (model module folder) │ ├── rnn_stratified_dataloader.py (datalaoder) │ ├── runjobs_utils.py (the first utility) │ ├── torch_utils.py (the second utility) │ └── models │ ├── run_model.sh (model tutorial) │ ├── LaPlacianMs.py │ ├── HiFiNet_deepfake.py │ └── ... └── environment.yml ``` ================================================ FILE: applications/deepfake_detection/dataset_test.py ================================================ # coding: utf-8 # author: Hierarchical Fine-Grained Image Forgery Detection and Localization import os import numpy as np import subprocess import logging import sys import torch import torch.nn as nn import torch.nn.functional as F import argparse import datetime from tensorboardX import SummaryWriter from torch.optim.lr_scheduler import ReduceLROnPlateau source_path = os.path.join('./sequence') sys.path.append(source_path) from rnn_stratified_dataloader import get_dataloader from models.HiFiNet_deepfake import HiFiNet_deepfake from torch_utils import eval_model,display_eval_tb,train_logging,lrSched_monitor from runjobs_utils import init_logger,Saver,DataConfig,torch_load_model logger = init_logger(__name__) logger.setLevel(logging.INFO) starting_time = datetime.datetime.now() ## Deterministic training _seed_id = 100 torch.backends.cudnn.deterministic = True torch.manual_seed(_seed_id) datasets = ['original', 'Deepfakes', 'FaceSwap', 'NeuralTextures', 'Face2Face'] # datasets = ['original', 'Deepfakes'] manipulations_names = [n for c, n in enumerate(datasets) if n != 'original'] manipulations_dict = {n : c for c, n in enumerate(manipulations_names) } manipulations_dict['original'] = 255 for key, value in manipulations_dict.items(): print(key, value) ctype = 'c40' # Create the parser parser = argparse.ArgumentParser(description='Process some integers.') parser.add_argument('--batch_size', type=int, default=4, help='input batch size for training (default: 32)') parser.add_argument('--window_size', type=int, default=5, help='size of the sliding window (default: 5)') parser.add_argument('--dataset_name', type=str, default="FF++", help='size of the sliding window (default: 5)') parser.add_argument('--gpus', type=int, default=4, help='input batch size for training (default: 32)') parser.add_argument('--feat_dim', type=int, default=270, help='input dim to rnn. (default: 32)') parser.add_argument('--valid_epoch', type=int, default=2, help='val epoch') parser.add_argument('--display_step', type=int, default=50, help='display the loss value.') parser.add_argument('--learning_rate', type=float, default=1e-3, help='the used learning rate') # Parse the arguments args = parser.parse_args() ## Hyper-params ####################### hparams = { 'epochs': 50, 'batch_size': args.batch_size, 'basic_lr': args.learning_rate, 'fine_tune': True, 'use_laplacian': True, 'step_factor': 0.1, 'patience': 20, 'weight_decay': 1e-06, 'lr_gamma': 2.0, 'use_magic_loss': True, 'feat_dim': args.feat_dim, 'drop_rate': 0.2, 'skip_valid': False, 'rnn_type': 'LSTM', 'rnn_hidden_size': 256, 'num_rnn_layers': 1, 'rnn_drop_rate': 0.2, 'bidir': False, 'merge_mode': 'concat', 'perc_margin_1': 0.95, 'perc_margin_2': 0.95, 'soft_boundary': False, 'dist_p': 2, 'radius_param': 0.84, 'strat_sampling': True, 'normalize': True, 'window_size': args.window_size, 'hop': 1, 'valid_epoch': args.valid_epoch, 'display_step': args.display_step, 'use_sched_monitor': True } batch_size = hparams['batch_size'] basic_lr = hparams['basic_lr'] fine_tune = hparams['fine_tune'] use_laplacian = hparams['use_laplacian'] step_factor = hparams['step_factor'] patience = hparams['patience'] weight_decay = hparams['weight_decay'] lr_gamma = hparams['lr_gamma'] use_magic_loss = hparams['use_magic_loss'] feat_dim = hparams['feat_dim'] drop_rate = hparams['drop_rate'] rnn_type = hparams['rnn_type'] rnn_hidden_size = hparams['rnn_hidden_size'] num_rnn_layers = hparams['num_rnn_layers'] rnn_drop_rate = hparams['rnn_drop_rate'] bidir = hparams['bidir'] merge_mode = hparams['merge_mode'] perc_margin_1 = hparams['perc_margin_1'] perc_margin_2 = hparams['perc_margin_2'] dist_p = hparams['dist_p'] radius_param = hparams['radius_param'] strat_sampling = hparams['strat_sampling'] normalize = hparams['normalize'] window_size = hparams['window_size'] hop = hparams['hop'] soft_boundary = hparams['soft_boundary'] use_sched_monitor = hparams['use_sched_monitor'] ######################################## workers_per_gpu = 6 dataset_name = f"{args.dataset_name}" exp_name = f"05_exp_c40_bs_{batch_size}_lr_{basic_lr}_ws_{window_size}" model_name = exp_name model_path = os.path.join(f'./{dataset_name}', model_name) print(f'Window_size: {args.window_size}; Dataset: {dataset_name}; Batch_Size: {batch_size}; LR: {basic_lr}.') # Create the model path if doesn't exists if not os.path.exists(model_path): subprocess.call(f"mkdir -p {model_path}", shell=True) ## Data Generation img_path = "/user/guoxia11/cvlshare/cvl-guoxia11/FaceForensics_HiFiNet" balanced_minibatch_opt = True if dataset_name == 'FF++': train_generator, train_dataset = get_dataloader( img_path, datasets, ctype, manipulations_dict, window_size, hop, use_laplacian, normalize, strat_sampling, balanced_minibatch_opt, 'train', batch_size, workers=workers_per_gpu*args.gpus ) test_generator, test_dataset = get_dataloader( img_path, datasets, ctype, manipulations_dict, window_size, hop, use_laplacian, normalize, strat_sampling, False, 'test', batch_size, workers=workers_per_gpu*args.gpus ) # print("the dataset length is: ", len(train_dataset)) print("the dataloader length is: ", len(train_generator)) # del train_dataset # del test_dataset elif dataset_name == "CelebDF": pass ## TODO: will be released in the near future. elif dataset_name == 'DFW': pass ## TODO: will be released in the near future. print('train: ', len(train_generator), len(train_dataset)) print('test: ', len(test_generator), len(test_dataset)) for ib, (img_batch_mmodal, true_labels, manip_type) in enumerate(train_generator,1): print(img_batch_mmodal.size(), true_labels.size(), manip_type[:2]) if ib == 1: break for ib, (img_batch_mmodal, true_labels, manip_type) in enumerate(test_generator,1): print(ib, img_batch_mmodal.size(), true_labels.size(), manip_type[:2]) if ib == 1: break print("...over...") ================================================ FILE: applications/deepfake_detection/dataset_test.sh ================================================ source ~/.bashrc conda activate HiFi_Net_deepfake CUDA_NUM=7 CUDA_VISIBLE_DEVICES=$CUDA_NUM python dataset_test.py \ --dataset_name FF++ \ --batch_size 32 \ --window_size 10 \ --gpus 1 \ --valid_epoch 1 \ --feat_dim 1000 \ --learning_rate 1e-4 \ --display_step 100 ================================================ FILE: applications/deepfake_detection/environment.yml ================================================ name: HiFi_Net_deepfake channels: - pytorch - conda-forge - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu - absl-py=1.3.0=py37h06a4308_0 - aiohttp=3.8.3=py37h5eee18b_0 - aiosignal=1.2.0=pyhd3eb1b0_0 - async-timeout=4.0.2=py37h06a4308_0 - asynctest=0.13.0=py_0 - attrs=22.1.0=py37h06a4308_0 - blas=1.0=mkl - blinker=1.4=py37h06a4308_0 - brotlipy=0.7.0=py37h27cfd23_1003 - bzip2=1.0.8=h7b6447c_0 - c-ares=1.19.1=h5eee18b_0 - ca-certificates=2023.12.12=h06a4308_0 - cachetools=4.2.2=pyhd3eb1b0_0 - certifi=2022.12.7=py37h06a4308_0 - cffi=1.15.1=py37h5eee18b_3 - charset-normalizer=2.0.4=pyhd3eb1b0_0 - click=8.0.4=py37h06a4308_0 - cryptography=39.0.1=py37h9ce1e76_0 - cudatoolkit=11.3.1=h2bc3f7f_2 - cycler=0.11.0=pyhd3eb1b0_0 - ffmpeg=4.3=hf484d3e_0 - fftw=3.3.9=h27cfd23_1 - freetype=2.12.1=h4a9f257_0 - frozenlist=1.3.3=py37h5eee18b_0 - giflib=5.2.1=h5eee18b_3 - gmp=6.2.1=h295c915_3 - gnutls=3.6.15=he1e5248_0 - google-auth=2.6.0=pyhd3eb1b0_0 - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 - grpcio=1.42.0=py37hce63b2e_0 - icu=67.1=he1b5a44_0 - idna=3.4=py37h06a4308_0 - imageio=2.9.0=pyhd3eb1b0_0 - importlib-metadata=4.11.3=py37h06a4308_0 - intel-openmp=2021.4.0=h06a4308_3561 - joblib=1.1.0=pyhd3eb1b0_0 - jpeg=9e=h5eee18b_1 - kiwisolver=1.4.4=py37h6a678d5_0 - lame=3.100=h7b6447c_0 - lcms2=2.12=h3be6417_0 - ld_impl_linux-64=2.38=h1181459_1 - lerc=3.0=h295c915_0 - libblas=3.9.0=12_linux64_mkl - libcblas=3.9.0=12_linux64_mkl - libdeflate=1.17=h5eee18b_1 - libffi=3.4.4=h6a678d5_0 - libgcc-ng=11.2.0=h1234567_1 - libgfortran-ng=11.2.0=h00389a5_1 - libgfortran5=11.2.0=h1234567_1 - libgomp=11.2.0=h1234567_1 - libiconv=1.16=h7f8727e_2 - libidn2=2.3.4=h5eee18b_0 - libpng=1.6.39=h5eee18b_0 - libprotobuf=3.20.3=he621ea3_0 - libstdcxx-ng=11.2.0=h1234567_1 - libtasn1=4.19.0=h5eee18b_0 - libtiff=4.5.1=h6a678d5_0 - libunistring=0.9.10=h27cfd23_0 - libuv=1.44.2=h5eee18b_0 - libwebp=1.2.4=h11a3e52_1 - libwebp-base=1.2.4=h5eee18b_1 - lz4-c=1.9.4=h6a678d5_0 - markdown=3.4.1=py37h06a4308_0 - markupsafe=2.1.1=py37h7f8727e_0 - matplotlib=3.2.2=1 - matplotlib-base=3.2.2=py37h1d35a4c_1 - mkl=2021.4.0=h06a4308_640 - mkl-service=2.4.0=py37h7f8727e_0 - mkl_fft=1.3.1=py37hd3c417c_0 - mkl_random=1.2.2=py37h51133e4_0 - multidict=6.0.2=py37h5eee18b_0 - ncurses=6.4=h6a678d5_0 - nettle=3.7.3=hbbd107a_1 - numpy=1.21.5=py37h6c91a56_3 - numpy-base=1.21.5=py37ha15fc14_3 - oauthlib=3.2.1=py37h06a4308_0 - openh264=2.1.1=h4ff587b_0 - openssl=1.1.1w=h7f8727e_0 - pillow=9.4.0=py37h6a678d5_0 - pip=23.3.2=pyhd8ed1ab_0 - protobuf=3.20.3=py37h6a678d5_0 - pyasn1=0.4.8=pyhd3eb1b0_0 - pyasn1-modules=0.2.8=py_0 - pycparser=2.21=pyhd3eb1b0_0 - pyjwt=2.4.0=py37h06a4308_0 - pyopenssl=23.0.0=py37h06a4308_0 - pyparsing=3.0.9=py37h06a4308_0 - pysocks=1.7.1=py37_1 - python=3.7.16=h7a1cb2a_0 - python-dateutil=2.8.2=pyhd3eb1b0_0 - python_abi=3.7=2_cp37m - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0 - pytorch-mutex=1.0=cuda - pyyaml=6.0=py37h5eee18b_1 - readline=8.2=h5eee18b_0 - requests=2.28.1=py37h06a4308_0 - requests-oauthlib=1.3.0=py_0 - rsa=4.7.2=pyhd3eb1b0_1 - scikit-learn=1.0.2=py37hf9e9bfc_0 - scipy=1.7.3=py37h6c91a56_2 - setuptools=68.2.2=pyhd8ed1ab_0 - six=1.16.0=pyhd3eb1b0_1 - sqlite=3.41.2=h5eee18b_0 - tensorboard=2.10.0=py37h06a4308_0 - tensorboard-data-server=0.6.1=py37h52d8a92_0 - tensorboard-plugin-wit=1.8.1=py37h06a4308_0 - threadpoolctl=2.2.0=pyh0d69192_0 - tk=8.6.12=h1ccaba5_0 - torchvision=0.12.0=py37_cu113 - tornado=5.1.1=py37h7b6447c_0 - tqdm=4.64.1=py37h06a4308_0 - typing-extensions=4.3.0=py37h06a4308_0 - typing_extensions=4.3.0=py37h06a4308_0 - urllib3=1.26.14=py37h06a4308_0 - werkzeug=2.2.2=py37h06a4308_0 - wheel=0.38.4=py37h06a4308_0 - xz=5.4.5=h5eee18b_0 - yacs=0.1.6=pyhd3eb1b0_1 - yaml=0.2.5=h7b6447c_0 - yarl=1.8.1=py37h5eee18b_0 - zipp=3.11.0=py37h06a4308_0 - zlib=1.2.13=h5eee18b_0 - zstd=1.5.5=hc292b87_0 - pip: - einops==0.6.1 - h5py==3.8.0 - kmeans-pytorch==0.3 - opencv-python==4.8.1.78 - packaging==24.0 - tensorboardx==2.6.2.2 ================================================ FILE: applications/deepfake_detection/exp_FF_c40_bs_32_lr_0.0001_ws_10.txt ================================================ AUC: 0.8829070609725371 Best Accuracy: 0.8590476190476191 (Threshold: 0.46431525609451324) TPR at FPR=10.0%: 0.6581349206349206 (Score: 0.9032643437385559) TPR at FPR=1.0%: 0.33174603174603173 (Score: 0.9792982339859009) Average Loss: 0.3208117520030077 ####################################################################################################AUC: 0.8959469482237339 Best Accuracy: 0.8698412698412699 (Threshold: 0.4833123738989796) TPR at FPR=10.0%: 0.7043650793650794 (Score: 0.9370318651199341) TPR at FPR=1.0%: 0.35694444444444445 (Score: 0.9900425672531128) Average Loss: 0.3183356274089535 ####################################################################################################AUC: 0.8979908352229781 Best Accuracy: 0.8706349206349207 (Threshold: 0.057323044208089015) TPR at FPR=10.0%: 0.709920634920635 (Score: 0.7713479399681091) TPR at FPR=1.0%: 0.3773809523809524 (Score: 0.9895150661468506) Average Loss: 0.44491299368715304 ####################################################################################################AUC: 0.9030002047115142 Best Accuracy: 0.8752380952380953 (Threshold: 0.1729503843006701) TPR at FPR=10.0%: 0.7001984126984127 (Score: 0.9017165899276733) TPR at FPR=1.0%: 0.4263888888888889 (Score: 0.9916896820068359) Average Loss: 0.3399755131538371 ####################################################################################################AUC: 0.8975945609725371 Best Accuracy: 0.8757142857142857 (Threshold: 0.2071308652196435) TPR at FPR=10.0%: 0.6819444444444445 (Score: 0.9187996983528137) TPR at FPR=1.0%: 0.3384920634920635 (Score: 0.9935241937637329) Average Loss: 0.3518789753537663 ####################################################################################################AUC: 0.8932506613756613 Best Accuracy: 0.8668253968253968 (Threshold: 0.031224684557531163) TPR at FPR=10.0%: 0.6708333333333333 (Score: 0.5181991457939148) TPR at FPR=1.0%: 0.3732142857142857 (Score: 0.973039448261261) Average Loss: 0.7130112742277117 ####################################################################################################AUC: 0.9064488063744018 Best Accuracy: 0.8771428571428571 (Threshold: 0.41749477599181195) TPR at FPR=10.0%: 0.7198412698412698 (Score: 0.9881836771965027) TPR at FPR=1.0%: 0.35535714285714287 (Score: 0.9986361861228943) Average Loss: 0.36584154823334025 ####################################################################################################AUC: 0.8815618701184177 Best Accuracy: 0.8687301587301587 (Threshold: 0.472512484058953) TPR at FPR=10.0%: 0.629563492063492 (Score: 0.9977478384971619) TPR at FPR=1.0%: 0.2505952380952381 (Score: 0.9991620779037476) Average Loss: 0.4665799030846784 ####################################################################################################AUC: 0.8969081475182665 Best Accuracy: 0.8742857142857143 (Threshold: 0.3673135946369051) TPR at FPR=10.0%: 0.6833333333333333 (Score: 0.9759820699691772) TPR at FPR=1.0%: 0.2388888888888889 (Score: 0.9987502098083496) Average Loss: 0.36425455615189173 ####################################################################################################AUC: 0.8948829207608969 Best Accuracy: 0.8776190476190476 (Threshold: 0.6733377333917661) TPR at FPR=10.0%: 0.6797619047619048 (Score: 0.9928593635559082) TPR at FPR=1.0%: 0.21805555555555556 (Score: 0.9992702603340149) Average Loss: 0.38597667283144527 ####################################################################################################AUC: 0.8982062547241119 Best Accuracy: 0.8712698412698413 (Threshold: 0.015597607697074736) TPR at FPR=10.0%: 0.7041666666666667 (Score: 0.6901902556419373) TPR at FPR=1.0%: 0.36468253968253966 (Score: 0.9896582961082458) Average Loss: 0.6102730501254047 ####################################################################################################AUC: 0.9000834593096498 Best Accuracy: 0.8765079365079365 (Threshold: 0.1510543103770636) TPR at FPR=10.0%: 0.7001984126984127 (Score: 0.9870874881744385) TPR at FPR=1.0%: 0.2859126984126984 (Score: 0.9993632435798645) Average Loss: 0.41046855883994304 ####################################################################################################AUC: 0.8990011652809271 Best Accuracy: 0.8803174603174603 (Threshold: 0.10573415606968273) TPR at FPR=10.0%: 0.6757936507936508 (Score: 0.9892104864120483) TPR at FPR=1.0%: 0.2892857142857143 (Score: 0.9993498921394348) Average Loss: 0.41615531263343497 ####################################################################################################AUC: 0.9041319444444444 Best Accuracy: 0.8761904761904762 (Threshold: 0.04045753150121847) TPR at FPR=10.0%: 0.7011904761904761 (Score: 0.9294343590736389) TPR at FPR=1.0%: 0.32063492063492066 (Score: 0.9988333582878113) Average Loss: 0.44224551875087514 ####################################################################################################AUC: 0.8955598072562357 Best Accuracy: 0.8823809523809524 (Threshold: 0.08566011418851711) TPR at FPR=10.0%: 0.6718253968253968 (Score: 0.9880918264389038) TPR at FPR=1.0%: 0.27996031746031746 (Score: 0.9995488524436951) Average Loss: 0.46919996677150333 ####################################################################################################AUC: 0.9041175359032503 Best Accuracy: 0.8798412698412699 (Threshold: 0.13584205501212096) TPR at FPR=10.0%: 0.7011904761904761 (Score: 0.9583638906478882) TPR at FPR=1.0%: 0.24841269841269842 (Score: 0.9993153810501099) Average Loss: 0.40308997611149433 ####################################################################################################AUC: 0.8985135582010583 Best Accuracy: 0.8792063492063492 (Threshold: 0.0554036657163878) TPR at FPR=10.0%: 0.6716269841269841 (Score: 0.9639698266983032) TPR at FPR=1.0%: 0.2152777777777778 (Score: 0.9995352029800415) Average Loss: 0.5122716583750035 ####################################################################################################AUC: 0.9058038863693626 Best Accuracy: 0.8850793650793651 (Threshold: 0.13075093828225343) TPR at FPR=10.0%: 0.7218253968253968 (Score: 0.9681994915008545) TPR at FPR=1.0%: 0.2623015873015873 (Score: 0.999649167060852) Average Loss: 0.43187202120434404 ####################################################################################################AUC: 0.8971601788863695 Best Accuracy: 0.8804761904761905 (Threshold: 0.045349182414935775) TPR at FPR=10.0%: 0.6609126984126984 (Score: 0.9611561894416809) TPR at FPR=1.0%: 0.2396825396825397 (Score: 0.9993257522583008) Average Loss: 0.5295804329411491 ####################################################################################################AUC: 0.9008590010078106 Best Accuracy: 0.8763492063492063 (Threshold: 0.40709965036355367) TPR at FPR=10.0%: 0.7013888888888888 (Score: 0.9943603873252869) TPR at FPR=1.0%: 0.28115079365079365 (Score: 0.9997544884681702) Average Loss: 0.43053449967426666 ####################################################################################################AUC: 0.9007028691106073 Best Accuracy: 0.8792063492063492 (Threshold: 0.11574143740985) TPR at FPR=10.0%: 0.703968253968254 (Score: 0.9931934475898743) TPR at FPR=1.0%: 0.20615079365079364 (Score: 0.9997492432594299) Average Loss: 0.45830285182233954 ####################################################################################################AUC: 0.8918712207105064 Best Accuracy: 0.871904761904762 (Threshold: 0.1257552100813786) TPR at FPR=10.0%: 0.6880952380952381 (Score: 0.9957960844039917) TPR at FPR=1.0%: 0.25416666666666665 (Score: 0.9998204112052917) Average Loss: 0.5295193400679405 ####################################################################################################AUC: 0.8913471592340638 Best Accuracy: 0.8784126984126984 (Threshold: 0.33674659807616425) TPR at FPR=10.0%: 0.6422619047619048 (Score: 0.9985455274581909) TPR at FPR=1.0%: 0.24285714285714285 (Score: 0.9997971653938293) Average Loss: 0.47251886236014523 ####################################################################################################AUC: 0.9127528187200807 Best Accuracy: 0.8819047619047619 (Threshold: 0.20754144801032828) TPR at FPR=10.0%: 0.7325396825396825 (Score: 0.9406241774559021) TPR at FPR=1.0%: 0.3759920634920635 (Score: 0.9955971837043762) Average Loss: 0.32738428150521126 ####################################################################################################AUC: 0.9102273872511968 Best Accuracy: 0.8811111111111111 (Threshold: 0.21728508113579234) TPR at FPR=10.0%: 0.7158730158730159 (Score: 0.9340832829475403) TPR at FPR=1.0%: 0.3998015873015873 (Score: 0.9958102703094482) Average Loss: 0.3444241999582326 ####################################################################################################AUC: 0.9138169249181154 Best Accuracy: 0.8822222222222222 (Threshold: 0.1270105143665013) TPR at FPR=10.0%: 0.7426587301587302 (Score: 0.9135438203811646) TPR at FPR=1.0%: 0.4218253968253968 (Score: 0.9957913160324097) Average Loss: 0.36578153728431817 ####################################################################################################AUC: 0.9142961073318218 Best Accuracy: 0.8849206349206349 (Threshold: 0.2570748374511923) TPR at FPR=10.0%: 0.7301587301587301 (Score: 0.9568803310394287) TPR at FPR=1.0%: 0.43353174603174605 (Score: 0.9967682361602783) Average Loss: 0.36079730476172367 ####################################################################################################AUC: 0.9118335065507686 Best Accuracy: 0.8817460317460317 (Threshold: 0.3071371658422034) TPR at FPR=10.0%: 0.7267857142857143 (Score: 0.9777267575263977) TPR at FPR=1.0%: 0.45575396825396824 (Score: 0.9973113536834717) Average Loss: 0.3544913220903623 ####################################################################################################AUC: 0.915317224111867 Best Accuracy: 0.8812698412698413 (Threshold: 0.14671899654303475) TPR at FPR=10.0%: 0.7494047619047619 (Score: 0.9343666434288025) TPR at FPR=1.0%: 0.4027777777777778 (Score: 0.9977713823318481) Average Loss: 0.3866839759710108 ####################################################################################################AUC: 0.9138940066767448 Best Accuracy: 0.8811111111111111 (Threshold: 0.1215629355189854) TPR at FPR=10.0%: 0.7525793650793651 (Score: 0.9413536190986633) TPR at FPR=1.0%: 0.4005952380952381 (Score: 0.9978323578834534) Average Loss: 0.3838400870690425 ####################################################################################################AUC: 0.9158619142101285 Best Accuracy: 0.8819047619047619 (Threshold: 0.1816996639657616) TPR at FPR=10.0%: 0.7444444444444445 (Score: 0.9434211850166321) TPR at FPR=1.0%: 0.40793650793650793 (Score: 0.997951328754425) Average Loss: 0.3850083784537087 ####################################################################################################AUC: 0.9120400289745527 Best Accuracy: 0.8815873015873016 (Threshold: 0.2168000961569648) TPR at FPR=10.0%: 0.7331349206349206 (Score: 0.9838338494300842) TPR at FPR=1.0%: 0.37936507936507935 (Score: 0.998543381690979) Average Loss: 0.36806364380399964 ####################################################################################################AUC: 0.9090080939783322 Best Accuracy: 0.88 (Threshold: 0.08131052524169635) TPR at FPR=10.0%: 0.7238095238095238 (Score: 0.9731644988059998) TPR at FPR=1.0%: 0.35138888888888886 (Score: 0.9985748529434204) Average Loss: 0.41489630684949136 ####################################################################################################AUC: 0.9134412005542958 Best Accuracy: 0.883968253968254 (Threshold: 0.17156062097213787) TPR at FPR=10.0%: 0.7331349206349206 (Score: 0.9726763367652893) TPR at FPR=1.0%: 0.38551587301587303 (Score: 0.9986814856529236) Average Loss: 0.39861634454801015 ####################################################################################################AUC: 0.9126887282690853 Best Accuracy: 0.8826984126984126 (Threshold: 0.3019230767371409) TPR at FPR=10.0%: 0.7279761904761904 (Score: 0.9891262054443359) TPR at FPR=1.0%: 0.3601190476190476 (Score: 0.9989126920700073) Average Loss: 0.3922838481644826 ####################################################################################################AUC: 0.9078669847568657 Best Accuracy: 0.8807936507936508 (Threshold: 0.30185993131420963) TPR at FPR=10.0%: 0.7170634920634921 (Score: 0.9934865236282349) TPR at FPR=1.0%: 0.3327380952380952 (Score: 0.9990170001983643) Average Loss: 0.416769449960152 ####################################################################################################AUC: 0.9052425831443689 Best Accuracy: 0.8809523809523809 (Threshold: 0.3770884994309789) TPR at FPR=10.0%: 0.703968253968254 (Score: 0.9969133138656616) TPR at FPR=1.0%: 0.3051587301587302 (Score: 0.9991865754127502) Average Loss: 0.4331670764465533 ####################################################################################################AUC: 0.9116062767699672 Best Accuracy: 0.8807936507936508 (Threshold: 0.11609432500454799) TPR at FPR=10.0%: 0.7152777777777778 (Score: 0.9900914430618286) TPR at FPR=1.0%: 0.3503968253968254 (Score: 0.9990679621696472) Average Loss: 0.4059253654547324 ####################################################################################################AUC: 0.909878511589821 Best Accuracy: 0.8838095238095238 (Threshold: 0.10603247603914004) TPR at FPR=10.0%: 0.7170634920634921 (Score: 0.986299455165863) TPR at FPR=1.0%: 0.3521825396825397 (Score: 0.9991635084152222) Average Loss: 0.47194165713981096 ####################################################################################################AUC: 0.9116446208112876 Best Accuracy: 0.8836507936507937 (Threshold: 0.25656318423296115) TPR at FPR=10.0%: 0.7313492063492063 (Score: 0.9949676394462585) TPR at FPR=1.0%: 0.37817460317460316 (Score: 0.9992499947547913) Average Loss: 0.4292217055991585 ####################################################################################################AUC: 0.9078786375661376 Best Accuracy: 0.8836507936507937 (Threshold: 0.1410691703695467) TPR at FPR=10.0%: 0.7085317460317461 (Score: 0.993381142616272) TPR at FPR=1.0%: 0.30813492063492065 (Score: 0.999306321144104) Average Loss: 0.4514668861100909 ####################################################################################################AUC: 0.9030014644746788 Best Accuracy: 0.8812698412698413 (Threshold: 0.32175635900079325) TPR at FPR=10.0%: 0.683531746031746 (Score: 0.9981260895729065) TPR at FPR=1.0%: 0.298015873015873 (Score: 0.9993946552276611) Average Loss: 0.47036494121964795 ####################################################################################################AUC: 0.9079246189216428 Best Accuracy: 0.8819047619047619 (Threshold: 0.2916260416192136) TPR at FPR=10.0%: 0.7218253968253968 (Score: 0.9950783252716064) TPR at FPR=1.0%: 0.3113095238095238 (Score: 0.9993897676467896) Average Loss: 0.4342052312062576 ####################################################################################################AUC: 0.9108625440917106 Best Accuracy: 0.8853968253968254 (Threshold: 0.04057335027773752) TPR at FPR=10.0%: 0.7077380952380953 (Score: 0.9875094294548035) TPR at FPR=1.0%: 0.3238095238095238 (Score: 0.999352753162384) Average Loss: 0.49118314668693375 ####################################################################################################AUC: 0.9123256802721089 Best Accuracy: 0.88 (Threshold: 0.19828409675377698) TPR at FPR=10.0%: 0.7353174603174604 (Score: 0.8316733241081238) TPR at FPR=1.0%: 0.36507936507936506 (Score: 0.9861753582954407) Average Loss: 0.3262811797141537 ####################################################################################################AUC: 0.9141698948097758 Best Accuracy: 0.8820634920634921 (Threshold: 0.19807075297263402) TPR at FPR=10.0%: 0.7535714285714286 (Score: 0.8300660252571106) TPR at FPR=1.0%: 0.36884920634920637 (Score: 0.9891144037246704) Average Loss: 0.32706290029395885 ####################################################################################################AUC: 0.9138121220710507 Best Accuracy: 0.8823809523809524 (Threshold: 0.18788503558754022) TPR at FPR=10.0%: 0.7436507936507937 (Score: 0.8511702418327332) TPR at FPR=1.0%: 0.3825396825396825 (Score: 0.9907612204551697) Average Loss: 0.32802260290183655 ####################################################################################################AUC: 0.914654903628118 Best Accuracy: 0.883015873015873 (Threshold: 0.21272393984804353) TPR at FPR=10.0%: 0.7424603174603175 (Score: 0.8644278049468994) TPR at FPR=1.0%: 0.4033730158730159 (Score: 0.9910193085670471) Average Loss: 0.3317383197679889 ####################################################################################################AUC: 0.9151813271604937 Best Accuracy: 0.8817460317460317 (Threshold: 0.27262781311795303) TPR at FPR=10.0%: 0.7418650793650794 (Score: 0.8735067844390869) TPR at FPR=1.0%: 0.4005952380952381 (Score: 0.9921466708183289) Average Loss: 0.32901159345366177 ####################################################################################################AUC: 0.9151310153691107 Best Accuracy: 0.8819047619047619 (Threshold: 0.13756239538235432) TPR at FPR=10.0%: 0.7498015873015873 (Score: 0.8710891008377075) TPR at FPR=1.0%: 0.4158730158730159 (Score: 0.9922192096710205) Average Loss: 0.33308415401604047 #################################################################################################### ================================================ FILE: applications/deepfake_detection/sequence/models/GaussianSmoothing.py ================================================ # author: Hierarchical Fine-Grained Image Forgery Detection and Localization, CVPR2023 import os import math import numbers import random import numpy as np import torch from torch import nn from torch.nn import functional as F class GaussianSmoothing(nn.Module): """ Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input using a depthwise convolution. Arguments: channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the gaussian kernel. dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial). """ def __init__(self, channels, kernel_size, sigma, dim=2): super(GaussianSmoothing, self).__init__() if isinstance(kernel_size, numbers.Number): kernel_size = [kernel_size] * dim if isinstance(sigma, numbers.Number): sigma = [sigma] * dim # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( [ torch.arange(size, dtype=torch.float32) for size in kernel_size ], indexing='ij' ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ torch.exp(-((mgrid - mean) / std) ** 2 / 2) # Make sure sum of values in gaussian kernel equals 1. kernel = kernel / torch.sum(kernel) # Reshape to depthwise convolutional weight kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) self.register_buffer('weight', kernel) self.groups = channels if dim == 1: self.conv = F.conv1d elif dim == 2: self.conv = F.conv2d elif dim == 3: self.conv = F.conv3d else: raise RuntimeError( 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) ) def forward(self, input): """ Apply gaussian filter to input. Arguments: input (torch.Tensor): Input to apply gaussian filter on. Returns: filtered (torch.Tensor): Filtered output. """ return self.conv(input, weight=self.weight, groups=self.groups) ================================================ FILE: applications/deepfake_detection/sequence/models/HiFiNet_deepfake.py ================================================ # coding: utf-8 # author: Hierarchical Fine-Grained Image Forgery Detection and Localization import torch import torch.nn as nn import torch.nn.functional as F import sys sys.path.append('./sequence/models') from hrnet.seg_hrnet_config import get_cfg_defaults from hrnet.seg_hrnet import get_seg_model class Flatten(nn.Module): def __init__(self): super(Flatten, self).__init__() def forward(self, x): return x.view(x.size(0), -1) class CatDepth(nn.Module): def __init__(self): super(CatDepth, self).__init__() def forward(self, x, y): return torch.cat([x,y],dim=1) class HiFiNet_deepfake(nn.Module): def __init__(self, use_laplacian=False, drop_rate=0.5, use_magic_loss=True, feat_dim = 1024, pretrained=True, rnn_type='LSTM', rnn_hidden_size=10, num_rnn_layers=1, rnn_drop_rate=0.5, bidir=False, merge_mode='concat',gate_type='sigmoid', device='cuda'): super(HiFiNet_deepfake, self).__init__() self.use_laplacian = use_laplacian self.feat_dim = feat_dim self.rnn_type = rnn_type self.rnn_input_size = feat_dim self.rnn_hidden_size = rnn_hidden_size self.num_rnn_layers = num_rnn_layers self.rnn_drop_rate = rnn_drop_rate self.bidir = bidir self.magic_loss = use_magic_loss self.device = device self.FENet = get_seg_model(get_cfg_defaults()).to(self.device) self.rnn = nn.LSTM(input_size=self.rnn_input_size, hidden_size=self.rnn_hidden_size, num_layers=self.num_rnn_layers, batch_first=False, dropout=self.rnn_drop_rate, bidirectional=self.bidir ) self.output_rnn = nn.Sequential(nn.ReLU(inplace=True), nn.Linear(256, 2)) # Select the merger function if merge_mode == 'concat': self.merger_function = merge_concat elif merge_mode == 'sum': self.merger_function = merge_sum def forward(self,x): batch_size, window_size, _, H, W = x.size() x = x.view(batch_size * window_size, 3, H, W) # Input for RGB branch conv_feat = self.FENet(x) z = conv_feat.view(batch_size, window_size, -1).permute(1,0,2) out, (h,c) = self.rnn(z) out = self.merger_function(out[-1, :, :self.rnn_hidden_size], out[0, :, self.rnn_hidden_size:]) out = self.output_rnn(out) return out def up (self,x, size): return F.interpolate(x,size=size,mode='bilinear', align_corners=False) def up_pix(self,x,r): return F.pixel_shuffle(x,r) ## Functions to merger the bidirectional outputs # Concatenation function def merge_concat(out1, out2): return torch.cat((out1, out2), 1) # Summation function def merge_sum(out1, out2): return torch.add(out1, out2) if __name__ == "__main__": import torch input = torch.randn((4, 1, 3, 224, 224)).cuda() # [64, 10, 3, 224, 224] model = HiFiNet_deepfake(use_laplacian=True, drop_rate=0.2, use_magic_loss=False, pretrained=True, rnn_drop_rate=0.2, feat_dim=1000, rnn_hidden_size=128, num_rnn_layers=2, bidir=True).cuda() model = torch.nn.DataParallel(model) print(f"...comes to this place...") output = model(input) print(f"the model output: ", output.size()) print("...over...") ================================================ FILE: applications/deepfake_detection/sequence/models/LaPlacianMs.py ================================================ # author: Hierarchical Fine-Grained Image Forgery Detection and Localization, CVPR2023 import os import torch import random import numpy as np import torch.nn as nn from torch.nn import functional as F try: from .GaussianSmoothing import GaussianSmoothing except: from GaussianSmoothing import GaussianSmoothing class LaPlacianMs(nn.Module): def __init__(self,in_c,gauss_ker_size=3,scale=[2],drop_rate=0.2): super(LaPlacianMs, self).__init__() self.scale = scale self.gauss_ker_size = gauss_ker_size ## apply gaussian smoothing to input feature maps with 3 planes ## with kernel size K and sigma s self.smoothing = nn.ModuleDict() for s in self.scale: self.smoothing['scale-'+str(s)] = GaussianSmoothing(in_c, self.gauss_ker_size, s) self.conv_1x1 = nn.Sequential(nn.Conv2d(in_c*len(scale), in_c, kernel_size=1, stride=1, bias=False,groups=1), nn.BatchNorm2d(in_c), nn.ReLU(inplace=True), # nn.Dropout(p=drop_rate) ) # Official init from torch repo. for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.constant_(m.bias, 0) def down(self,x,s): return F.interpolate(x,scale_factor=s, mode='bilinear', align_corners=False) def up (self,x, size): return F.interpolate(x,size=size,mode='bilinear',align_corners=False) def forward(self, x): for i, s in enumerate(self.scale): sm = self.smoothing['scale-'+str(s)](x) sm = self.up(self.down(sm,1/s),(x.shape[2],x.shape[3])) if i == 0: diff = x - sm else: diff = torch.cat((diff, x - sm), dim=1) return self.conv_1x1(diff) ================================================ FILE: applications/deepfake_detection/sequence/models/hrnet/hrnet_w18_small_model_v2.pth ================================================ [File too large to display: 59.8 MB] ================================================ FILE: applications/deepfake_detection/sequence/models/hrnet/seg_hrnet.py ================================================ # ------------------------------------------------------------------------------ # Copyright (c) Microsoft # Licensed under the MIT License. # The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn) # ------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import division from __future__ import print_function from LaPlacianMs import LaPlacianMs import os import logging import numpy as np import torch import torch.nn as nn import torch._utils import torch.nn.functional as F BN_MOMENTUM = 0.01 logger = logging.getLogger(__name__) # noise generation def srm_generation(image): """ :param image: N * C * H * W :return: noises """ # srm kernel 1 srm1 = np.zeros([5, 5]).astype('float32') srm1[1:-1, 1:-1] = np.array([[-1, 2, -1], [2, -4, 2], [-1, 2, -1]]) srm1 /= 4. # srm kernel 2 srm2 = np.array([[-1, 2, -2, 2, -1], [2, -6, 8, -6, 2], [-2, 8, -12, 8, -2], [2, -6, 8, -6, 2], [-1, 2, -2, 2, -1]]).astype('float32') srm2 /= 12. # srm kernel 3 srm3 = np.zeros([5, 5]).astype('float32') srm3[2, 1:-1] = np.array([1, -2, 1]) srm3 /= 2. srm = np.stack([srm1, srm2, srm3], axis=0) W_srm = np.zeros([3, 3, 5, 5]).astype('float32') for i in range(3): W_srm[i, 0, :, :] = srm[i, :, :] W_srm[i, 1, :, :] = srm[i, :, :] W_srm[i, 2, :, :] = srm[i, :, :] W_srm = torch.from_numpy(W_srm).to(image.get_device()) srm_noise = F.conv2d(image, W_srm, padding=2) return srm_noise # bayar constrained layer class BayarConstraint(object): def __init__(self): pass def __call__(self, module): if hasattr(module, 'weight'): weight = module.weight.data # oc, ic, h, w h, w = weight.size()[2:] mask = torch.zeros_like(weight) mask[:, :, h//2, w//2] = 1 weight *= (1 - mask) rest_sum = torch.sum(weight, dim=(2, 3), keepdim=True) weight /= (rest_sum + 1e-7) weight -= mask module.weight.data = weight def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class CatDepth(nn.Module): def __init__(self): super(CatDepth, self).__init__() def forward(self, x, y): return torch.cat([x,y],dim=1) def weights_init(init_type='gaussian'): def init_fun(m): classname = m.__class__.__name__ if (classname.find('Conv') == 0 or classname.find( 'Linear') == 0) and hasattr(m, 'weight'): if init_type == 'gaussian': nn.init.normal_(m.weight, 0.0, 0.02) elif init_type == 'xavier': nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) elif init_type == 'kaiming': nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') elif init_type == 'orthogonal': nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) elif init_type == 'default': pass else: assert 0, "Unsupported initialization: {}".format(init_type) if hasattr(m, 'bias') and m.bias is not None: nn.init.constant_(m.bias, 0.0) return init_fun '''GX: basicblock contains two conv3x3 and two batch norm''' '''GX: at last, it has a residual connection''' class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=False) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu(out) return out '''GX: 3 conv + 3 bn then a residual.''' class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=False) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu(out) return out '''GX: the basic component in the network.''' class HighResolutionModule(nn.Module): def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True): super(HighResolutionModule, self).__init__() self._check_branches( num_branches, blocks, num_blocks, num_inchannels, num_channels) self.num_inchannels = num_inchannels self.fuse_method = fuse_method self.num_branches = num_branches self.multi_scale_output = multi_scale_output self.branches = self._make_branches( num_branches, blocks, num_blocks, num_channels) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(inplace=False) def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): if num_branches != len(num_blocks): error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( num_branches, len(num_blocks)) raise ValueError(error_msg) if num_branches != len(num_channels): error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( num_branches, len(num_channels)) raise ValueError(error_msg) if num_branches != len(num_inchannels): error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( num_branches, len(num_inchannels)) raise ValueError(error_msg) def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None if stride != 1 or \ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), ) layers = [] layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)) self.num_inchannels[branch_index] = \ num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) return nn.Sequential(*layers) def _make_branches(self, num_branches, block, num_blocks, num_channels): branches = [] for i in range(num_branches): branches.append( self._make_one_branch(i, block, num_blocks, num_channels)) return nn.ModuleList(branches) ## GX: fuse layer converts feature maps at different resolution branches ## GX: into the feature map of the new branches' feature map. ## GX: https://zhuanlan.zhihu.com/p/335333233 def _make_fuse_layers(self): if self.num_branches == 1: return None num_branches = self.num_branches num_inchannels = self.num_inchannels fuse_layers = [] for i in range(num_branches if self.multi_scale_output else 1): fuse_layer = [] for j in range(num_branches): if j > i: fuse_layer.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) elif j == i: fuse_layer.append(None) else: conv3x3s = [] for k in range(i - j): if k == i - j - 1: num_outchannels_conv3x3 = num_inchannels[i] conv3x3s.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM))) else: num_outchannels_conv3x3 = num_inchannels[j] conv3x3s.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM), nn.ReLU(inplace=False))) fuse_layer.append(nn.Sequential(*conv3x3s)) fuse_layers.append(nn.ModuleList(fuse_layer)) return nn.ModuleList(fuse_layers) def get_num_inchannels(self): return self.num_inchannels def forward(self, x): if self.num_branches == 1: return [self.branches[0](x[0])] for i in range(self.num_branches): x[i] = self.branches[i](x[i]) x_fuse = [] for i in range(len(self.fuse_layers)): y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) for j in range(1, self.num_branches): if i == j: y = y + x[j] elif j > i: width_output = x[i].shape[-1] height_output = x[i].shape[-2] y = y + F.interpolate( self.fuse_layers[i][j](x[j]), size=[height_output, width_output], mode='bilinear', align_corners=True) else: y = y + self.fuse_layers[i][j](x[j]) x_fuse.append(self.relu(y)) return x_fuse blocks_dict = { 'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck } ## GX: the HighResolutionNet has 4 stages. ## GX: each stage has one module which is HighResolutionModule. ## GX: HighResolutionModule has 1,2,3,4 branches. ## GX: each stage has a transitional layers in between. class HighResolutionNet(nn.Module): def __init__(self, config, **kwargs): super(HighResolutionNet, self).__init__() # noise conv # self.im_conv = nn.Conv2d(3, 10, kernel_size=3, stride=1, padding=1, bias=False) # self.bayar_conv = nn.Conv2d(3, 3, kernel_size=5, stride=1, padding=2, bias=False) # self.constraints = BayarConstraint() # stem net self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=False) # # frequency branch # self.conv1fre = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) # self.bn1fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) # self.conv2fre = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) # self.bn2fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) # self.laplacian = LaPlacianMs(in_c=64,gauss_ker_size=3,scale=[2,4,8]) # concat self.concat_depth = CatDepth() self.conv_1x1_merge = nn.Sequential(nn.Conv2d(128, 64, kernel_size=1, stride=1, bias=False,groups=2), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Dropout(p=0.2) ) self.conv_1x1_merge.apply(weights_init('kaiming')) self.stage1_cfg = config['STAGE1'] num_channels = self.stage1_cfg['NUM_CHANNELS'][0] block = blocks_dict[self.stage1_cfg['BLOCK']] num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) stage1_out_channel = block.expansion * num_channels self.stage2_cfg = config['STAGE2'] num_channels = self.stage2_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage2_cfg['BLOCK']] num_channels = [ num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition1 = self._make_transition_layer( [stage1_out_channel], num_channels) self.stage2, pre_stage_channels = self._make_stage( self.stage2_cfg, num_channels) self.stage3_cfg = config['STAGE3'] num_channels = self.stage3_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage3_cfg['BLOCK']] num_channels = [ num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition2 = self._make_transition_layer( pre_stage_channels, num_channels) self.stage3, pre_stage_channels = self._make_stage( self.stage3_cfg, num_channels) self.stage4_cfg = config['STAGE4'] num_channels = self.stage4_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage4_cfg['BLOCK']] num_channels = [ num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition3 = self._make_transition_layer( pre_stage_channels, num_channels) self.stage4, pre_stage_channels = self._make_stage( self.stage4_cfg, num_channels, multi_scale_output=True) # last_inp_channels = np.int(np.sum(pre_stage_channels)) # Classification Head self.incre_modules, self.downsamp_modules, \ self.final_layer = self._make_head(pre_stage_channels) self.classifier = nn.Linear(2048, 1000) def _make_head(self, pre_stage_channels): head_block = Bottleneck head_channels = [32, 64, 128, 256] # Increasing the #channels on each resolution # from C, 2C, 4C, 8C to 128, 256, 512, 1024 incre_modules = [] for i, channels in enumerate(pre_stage_channels): incre_module = self._make_layer(head_block, channels, head_channels[i], 1, stride=1) incre_modules.append(incre_module) incre_modules = nn.ModuleList(incre_modules) # downsampling modules downsamp_modules = [] for i in range(len(pre_stage_channels)-1): in_channels = head_channels[i] * head_block.expansion out_channels = head_channels[i+1] * head_block.expansion downsamp_module = nn.Sequential( nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), nn.ReLU(inplace=True) ) downsamp_modules.append(downsamp_module) downsamp_modules = nn.ModuleList(downsamp_modules) final_layer = nn.Sequential( nn.Conv2d( in_channels=head_channels[3] * head_block.expansion, out_channels=2048, kernel_size=1, stride=1, padding=0 ), nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), nn.ReLU(inplace=True) ) return incre_modules, downsamp_modules, final_layer ## GX: one dimension matrix converts pre to pos. ## GX: if channel numbers are equal, pass it directly. ## GX: if channel numbers are different, using conv 3x3. ## GX: https://zhuanlan.zhihu.com/p/335333233 def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) transition_layers = [] for i in range(num_branches_cur): if i < num_branches_pre: if num_channels_cur_layer[i] != num_channels_pre_layer[i]: transition_layers.append(nn.Sequential( nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), nn.BatchNorm2d( num_channels_cur_layer[i], momentum=BN_MOMENTUM), nn.ReLU(inplace=False))) else: transition_layers.append(None) else: conv3x3s = [] for j in range(i + 1 - num_branches_pre): inchannels = num_channels_pre_layer[-1] outchannels = num_channels_cur_layer[i] \ if j == i - num_branches_pre else inchannels conv3x3s.append(nn.Sequential( nn.Conv2d( inchannels, outchannels, 3, 2, 1, bias=False), nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), nn.ReLU(inplace=False))) transition_layers.append(nn.Sequential(*conv3x3s)) return nn.ModuleList(transition_layers) ## GX: _make_layer creates a conv + bn def _make_layer(self, block, inplanes, planes, blocks, stride=1): downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), ) layers = [] layers.append(block(inplanes, planes, stride, downsample)) inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(inplanes, planes)) return nn.Sequential(*layers) def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): ## GX: num_modules are all 1 in this work. ## GX: light-weight architectures: num_blocks are all 0. ## GX: branch numbers are 2, 3, 4. num_modules = layer_config['NUM_MODULES'] num_branches = layer_config['NUM_BRANCHES'] num_blocks = layer_config['NUM_BLOCKS'] num_channels = layer_config['NUM_CHANNELS'] block = blocks_dict[layer_config['BLOCK']] fuse_method = layer_config['FUSE_METHOD'] modules = [] for i in range(num_modules): # multi_scale_output is only used last module if not multi_scale_output and i == num_modules - 1: reset_multi_scale_output = False else: reset_multi_scale_output = True modules.append( HighResolutionModule(num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output) ) num_inchannels = modules[-1].get_num_inchannels() return nn.Sequential(*modules), num_inchannels def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = self.layer1(x) x_list = [] for i in range(self.stage2_cfg['NUM_BRANCHES']): if self.transition1[i] is not None: x_list.append(self.transition1[i](x)) else: x_list.append(x) y_list = self.stage2(x_list) x_list = [] for i in range(self.stage3_cfg['NUM_BRANCHES']): if self.transition2[i] is not None: if i < self.stage2_cfg['NUM_BRANCHES']: x_list.append(self.transition2[i](y_list[i])) else: x_list.append(self.transition2[i](y_list[-1])) else: x_list.append(y_list[i]) y_list = self.stage3(x_list) x_list = [] for i in range(self.stage4_cfg['NUM_BRANCHES']): if self.transition3[i] is not None: if i < self.stage3_cfg['NUM_BRANCHES']: x_list.append(self.transition3[i](y_list[i])) else: x_list.append(self.transition3[i](y_list[-1])) else: x_list.append(y_list[i]) y_list = self.stage4(x_list) # Classification Head y = self.incre_modules[0](y_list[0]) for i in range(len(self.downsamp_modules)): y = self.incre_modules[i+1](y_list[i+1]) + \ self.downsamp_modules[i](y) y = self.final_layer(y) if torch._C._get_tracing_state(): y = y.flatten(start_dim=2).mean(dim=2) else: y = F.avg_pool2d(y, kernel_size=y.size() [2:]).view(y.size(0), -1) y = self.classifier(y) return y def init_weights(self, pretrained='',): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, std=0.001) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) if os.path.isfile(pretrained): ## GX: official pre-trained dict. pretrained_dict = torch.load(pretrained) print('=> loading HRNet pretrained model {}'.format(pretrained)) model_dict = self.state_dict() ## GX: the current model. nopretrained_dict = {k: v for k, v in model_dict.items()} pretrained_dict_used = {} for k, v in model_dict.items(): pretrained_key = k if pretrained_key not in pretrained_dict.keys(): if 'stage2' in pretrained_key and 'fuse_layers' not in pretrained_key: if 'branches.2' in pretrained_key: pretrained_key = pretrained_key.replace('stage2.0.', 'stage3.0.') elif 'branches.3' in pretrained_key: pretrained_key = pretrained_key.replace('stage2.0.', 'stage4.0.') elif 'stage3' in pretrained_key and 'fuse_layers' not in pretrained_key: pretrained_key = pretrained_key.replace('stage3.0.', 'stage4.0.') elif 'fre' in pretrained_key: pretrained_key = pretrained_key.replace('fre', '') if pretrained_key in pretrained_dict.keys(): pretrained_dict_used[k] = pretrained_dict[pretrained_key] nopretrained_dict.pop(k) print("no pretrain dict length is: ", len(nopretrained_dict)) ## GX: how many parameters you need to train on your own. model_dict.update(pretrained_dict_used) self.load_state_dict(model_dict) else: print(f"{pretrained} does NOT exist.") print(f"Please try to load the pre-trained weights of HR-Net.") import sys;sys.exit(0) def get_seg_model(cfg, **kwargs): model = HighResolutionNet(cfg, **kwargs) model.init_weights(cfg.PRETRAINED) return model ================================================ FILE: applications/deepfake_detection/sequence/models/hrnet/seg_hrnet_config.py ================================================ # ------------------------------------------------------------------------------ # Copyright (c) Microsoft # Licensed under the MIT License. # The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn) # ------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import division from __future__ import print_function from yacs.config import CfgNode as CN # high_resoluton_net related params for segmentation HRNET = CN() HRNET.PRETRAINED_LAYERS = ['*'] HRNET.STEM_INPLANES = 64 HRNET.FINAL_CONV_KERNEL = 1 HRNET.PRETRAINED = './sequence/models/hrnet/hrnet_w18_small_model_v2.pth' HRNET.STAGE1 = CN() HRNET.STAGE1.NUM_MODULES = 1 HRNET.STAGE1.NUM_BRANCHES = 1 HRNET.STAGE1.NUM_BLOCKS = [2] HRNET.STAGE1.NUM_CHANNELS = [64] HRNET.STAGE1.BLOCK = 'BOTTLENECK' HRNET.STAGE1.FUSE_METHOD = 'SUM' HRNET.STAGE2 = CN() HRNET.STAGE2.NUM_MODULES = 1 HRNET.STAGE2.NUM_BRANCHES = 4 HRNET.STAGE2.NUM_BLOCKS = [2, 2, 2, 2] HRNET.STAGE2.NUM_CHANNELS = [18, 36, 72, 144] HRNET.STAGE2.BLOCK = 'BASIC' HRNET.STAGE2.FUSE_METHOD = 'SUM' HRNET.STAGE3 = CN() HRNET.STAGE3.NUM_MODULES = 1 HRNET.STAGE3.NUM_BRANCHES = 4 HRNET.STAGE3.NUM_BLOCKS = [2, 2, 2, 2] HRNET.STAGE3.NUM_CHANNELS = [18, 36, 72, 144] HRNET.STAGE3.BLOCK = 'BASIC' HRNET.STAGE3.FUSE_METHOD = 'SUM' HRNET.STAGE4 = CN() HRNET.STAGE4.NUM_MODULES = 1 HRNET.STAGE4.NUM_BRANCHES = 4 HRNET.STAGE4.NUM_BLOCKS = [2, 2, 2, 2] HRNET.STAGE4.NUM_CHANNELS = [18, 36, 72, 144] HRNET.STAGE4.BLOCK = 'BASIC' HRNET.STAGE4.FUSE_METHOD = 'SUM' def get_cfg_defaults(): """Get a yacs CfgNode object with default values for my_project.""" # Return a clone so that the defaults will not be altered # This is for the "local variable" use pattern return HRNET.clone() if __name__ == "__main__": print("Hello World!") ================================================ FILE: applications/deepfake_detection/sequence/models/run_model.sh ================================================ source ~/.bashrc conda activate HiFi_Net_deepfake CUDA_NUM=2 CUDA_VISIBLE_DEVICES=$CUDA_NUM python HiFiNet_deepfake.py ================================================ FILE: applications/deepfake_detection/sequence/rnn_stratified_dataloader.py ================================================ # coding: utf-8 # author: Hierarchical Fine-Grained Image Forgery Detection and Localization, CVPR2023 # based on the sample strategy proposed in Two-branch Recurrent Network for Isolating Deepfakes in Videos, ECCV2020 import torch import torchvision import h5py import os import glob import numpy as np import json import numpy as np from torch.utils import data # Image transformation def get_image_transformation(use_laplacian=False, normalize=True): transforms = [] if normalize: transforms.extend( [torchvision.transforms.ToPILImage(), # Next line takes PIL images as input (ToPILImage() preserves the values in the input array or tensor) torchvision.transforms.ToTensor(), # To bring the pixel values in the range [0,1] torchvision.transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))] ) return torchvision.transforms.Compose(transforms) else: transforms.extend( [torchvision.transforms.ToPILImage(), # Next line takes PIL images as input (ToPILImage() preserves the values in the input array or tensor) torchvision.transforms.ToTensor()] # To bring the pixel values in the range [0,1] ) return torchvision.transforms.Compose(transforms) # Main dataloader def get_dataloader(img_path,train_dataset_names,ctype,manipulations_dict,window_size=10,hop=1,use_laplacian=False,normalize=True,strat_sampling=False,balanced_minibatch=False,mode='train',bs=32,workers=4): """ This is a dataloader for Face Forensics++ dataset stored in HDF5 file format. The structure of the files should be as shown below: filename.h5 -> keys (video names. Ex, 000_003 for manipulated and 000 for original) -> each video will further have 'n' number of frames. f[key][i] to acces 'ith' frame of 'key' video. Example of filename: FF++_Deepfakes_c40.h5, FF++_Face2Face_c23.h5, FF++_original_c0.h5, etc. Parameters ---------- img_path : str The location of h5 files on hard drive. train_dataset_names : list The datasets that are to be loaded. returns ------- out: torch.utils.data.dataloader.DataLoader A generator that can be used to get the required batches of sequential samples of data. Examples -------- img_path = '/research/cvlshare/cvl-guoxia11/FaceForensics++' train_dataset_names = ['original', 'Deepfakes'] ctype = 'c40' manipulations_dict = {0:'Deepfakes',255:'original'} window_size = 10 hop = 5 use_laplacian = True normalize = True strat_sampling = True mode='train' bs=32 workers=0 train_generator = get_dataloader(img_path,train_dataset_names,ctype,manipulations_dict,window_size,hop,use_laplacian,normalize,strat_sampling,mode,bs,workers) """ transform = get_image_transformation(use_laplacian=False, normalize=normalize) params = {'batch_size': bs, 'shuffle': (mode=='train'), 'num_workers': workers, 'drop_last' : (mode=='train') } if mode == 'test' or mode == 'val': strat_sampling = False datalist_dict = get_img_list(img_path, train_dataset_names, ctype, mode, window_size, hop, strat_sampling, balanced_minibatch) datasets = { dataset_key : ForensicFaceDatasetRNN(img_list, img_path, dataset_key, ctype, manipulations_dict, window_size, hop=hop, use_laplacian=use_laplacian, strat_sampling=strat_sampling, transform=transform) for dataset_key, img_list in datalist_dict.items() } joined_dataset = data.ConcatDataset([dataset for keys, dataset in datasets.items() ]) joined_generator = data.DataLoader(joined_dataset,**params,pin_memory=True) return joined_generator, joined_dataset # Generate a dictionary with "dataset": [dataset-video_id-frame_start] def get_img_list(img_path, datasets, ctype, split, window_size, hop, strat_sampling, balanced_minibatch, repeat_num=6): # Get the video_ids based on the split if split == 'train': with open('/research/cvl-guoxia11/deepfake_AIGC/FaceForensics/dataset/splits/train.json', 'r') as f_json: img_folders = json.load(f_json) elif split == 'val': with open('/research/cvl-guoxia11/deepfake_AIGC/FaceForensics/dataset/splits/val.json', 'r') as f_json: img_folders = json.load(f_json) elif split == 'test': with open('/research/cvl-guoxia11/deepfake_AIGC/FaceForensics/dataset/splits/test.json', 'r') as f_json: img_folders = json.load(f_json) data_dict = {} for dataset in datasets: data_list = [] data_filename = glob.glob(f'{img_path}/*{dataset}*{ctype}*.h5')[0] # Find the correct data file in the img_path f = h5py.File(data_filename, 'r') # Load the data file in f tmp_img_folders = [] if dataset == "original": tmp_img_folder = [x for sublist in img_folders for x in sublist] if split == 'train' and strat_sampling and balanced_minibatch: for i in range(4*repeat_num): tmp_img_folders.extend(tmp_img_folder) # Oversample by 4, then it has 2880 sequences. else: tmp_img_folders = tmp_img_folder else: _ = list(map(lambda x:["_".join([x[0],x[1]]),"_".join([x[1],x[0]])], img_folders)) tmp_img_folder = [x for sublist in _ for x in sublist] if split == 'train' and strat_sampling and balanced_minibatch: for i in range(repeat_num): tmp_img_folders.extend(tmp_img_folder) # Oversample by 4, then it has 2880 sequences. else: tmp_img_folders = tmp_img_folder for folder in tmp_img_folders: if strat_sampling: frame_limit = f[folder].shape[0] if frame_limit > window_size*hop: ## we record: the dataset name, the video id (folder) and total number of frames (frame_limit) data_list.append(f'{dataset}-{folder}-{frame_limit}') else: # Get the indices of the starting frame of each chunk of frames if f[folder].shape[0] > window_size*hop: frame_start_indices = np.arange(0, f[folder].shape[0]-(window_size*hop), window_size*hop) for frame_index in frame_start_indices: data_list.append(f'{dataset}-{folder}-{frame_index}') f.close() data_dict[dataset] = data_list return data_dict class ForensicFaceDatasetRNN(data.Dataset): def __init__(self, list_ids, img_path, dataset_name, ctype, manipulations_dict, window_size, hop, use_laplacian=False, strat_sampling=False, transform=[]): super(ForensicFaceDatasetRNN, self).__init__() self.list_ids = list_ids self.transform = transform self.use_laplacian = use_laplacian self.strat_sampling = strat_sampling self.dataset_name = dataset_name self.dname_to_id = manipulations_dict self.window_size = window_size self.hop = hop self.h5_handler = None self.data_filename = self.get_dbfile_path(f'{img_path}/*{dataset_name}*{ctype}*.h5') if not os.path.exists(self.data_filename): raise RunTimeError('%s not found' % (self.data_filename)) if self.hop < 1: raise ValueError(f'Minimum value of hop is 1. And you provided {self.hop}') def __len__(self): return len(self.list_ids) def get_dbfile_path(self,path_pattern): list_files = glob.glob(path_pattern) n_files = len(list_files) if n_files >=2: raise RuntimeError(f'Found multiple files in {path_pattern}') elif n_files == 0: raise RuntimeError(f'Files not found in {path_pattern}') else: return list_files[0] def __getitem__(self, index): if self.h5_handler is None: self.h5_handler = h5py.File(self.data_filename, 'r', swmr=True) file_id = self.list_ids[index].split('-') data_folder = file_id[1] if self.strat_sampling: frame_limit = file_id[2] ## now we random sample a frame within the video frame_id = np.random.randint(0,int(frame_limit)-(self.window_size*self.hop)) else: frame_id = file_id[2] frames = self.h5_handler[data_folder][int(frame_id):int(frame_id)+(self.window_size*self.hop):self.hop] ## Now handling the label label = 1.0 if self.dataset_name == "original" else 0.0 ''' ## visualization example: import cv2 print(f"the frames are: ", frames.shape) # output_frames = self.transform(frames) for _ in range(10): frame = frames[_] # print(f"the frame is: ", frame.shape) # print("output frames: ", frame.shape) image_data = frame.astype(np.uint8) image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB) # cv2.imshow('demo.png', image_data) cv2.imwrite(f'demo_{_}_{self.dataset_name}.png', image_data) ''' frames = torch.stack(list(map(self.transform,frames))) image_names = '~'.join([f"{data_folder}/{int(frame_id) + i * self.hop}" for i in range(self.window_size)]) return frames, label, image_names ================================================ FILE: applications/deepfake_detection/sequence/runjobs_utils.py ================================================ # coding: utf-8 # author: Hierarchical Fine-Grained Image Forgery Detection and Localization import datetime import logging import sys import torch import os import datetime def init_logger(name): logger = logging.getLogger(name) h = logging.StreamHandler(sys.stdout) h.flush = sys.stdout.flush logger.addHandler(h) return logger logger = init_logger(__name__) logger.setLevel(logging.INFO) def torch_load_model(model, optimizer, load_model_path,strict=True): loaded_file = torch.load(load_model_path) model.load_state_dict(loaded_file['model_state_dict'], strict=strict) # model.load_state_dict(loaded_file['model_state_dict'], strict=False) iteration = loaded_file['iter'] scheduler = loaded_file['scheduler'] epoch = loaded_file['epoch'] val_loss = 1.0 if 'val_loss' in loaded_file: val_loss = loaded_file['val_loss'] # optimizer.load_state_dict(loaded_file['optimizer_state_dict']) return iteration, epoch, scheduler, val_loss class DataConfig(object): def __init__(self, model_path, model_name): self.model_path = model_path self.model_name = model_name class Saver(object): def __init__(self, model, optimizer, scheduler, data_config, starting_time, hours_limit=23, mins_limit=0): self.model = model self.optimizer = optimizer self.scheduler = scheduler self.best_val_loss = sys.maxsize self.data_config = data_config self.hours_limit = hours_limit self.mins_limit = mins_limit self.starting_time = starting_time def save_model(self,epoch,ib,val_loss,before_train,best_only=False,force_saving=False): # if (val_loss <= self.best_val_loss and not(before_train)) or force_saving: if val_loss <= self.best_val_loss or force_saving: ## preserving best_loss if val_loss <= self.best_val_loss: self.best_val_loss = val_loss if best_only: saving_list = [os.path.join(self.data_config.model_path,'best_model.pth')] if force_saving: saving_list = [os.path.join(self.data_config.model_path,'current_model.pth')] print("===================================") print(f"saving model list is: ", saving_list) print("===================================") for ss in saving_list: torch.save({'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict() if self.optimizer is not None else None, 'iter' : ib, 'scheduler' : self.scheduler, 'val_loss' : val_loss, }, ss ) def check_time(self): this_time = datetime.datetime.now() days, hours, mins = self.days_hours_minutes( this_time - self.starting_time) return days, hours, mins def days_hours_minutes(self, td): return td.days, td.seconds//3600, (td.seconds//60) % 60 ================================================ FILE: applications/deepfake_detection/sequence/torch_utils.py ================================================ # coding: utf-8 # author: Hierarchical Fine-Grained Image Forgery Detection and Localization import torch import torch.nn as nn from tqdm import tqdm from sklearn import metrics import numpy as np from runjobs_utils import init_logger import logging import torch.nn.functional as F import os from collections import OrderedDict import csv logger = init_logger(__name__) logger.setLevel(logging.INFO) class ROC(object): def __init__(self): self.fpr = None self.tpr = None self.auc = None self.scores = None self.ap_0 = None self.ap_1 = None self.weighted_ap = None self.predictions = [] self.gt = [] self.best_acc = None def get_trunc_auc(self,fpr_value): abs_fpr = np.absolute(self.fpr - fpr_value) idx_min = np.argmin(abs_fpr) area_curve = sum(self.tpr[idx_min]) tot_area = sum(np.ones_like(self.tpr)[idx_min]) if tot_area == 0: raise ZeroDivisionError('when computing truncated ROC aread') t_auc = area_curve/tot_area return t_auc def get_tpr_at_fpr(self,fpr_value): abs_fpr = np.absolute(self.fpr - fpr_value) idx_min = np.argmin(abs_fpr) fpr_value_target = self.fpr[idx_min] idx = np.max(np.where(self.fpr == fpr_value_target)) return self.tpr[idx], self.scores[idx] def eval(self): self.fpr, self.tpr, self.scores = metrics.roc_curve(self.gt,self.predictions,drop_intermediate=True) self.auc = metrics.auc(self.fpr,self.tpr) def compute_best_accuracy(self,n_samples=200): '''find the best threshold for the accuracy.''' acc_thrs = [] min_thr = min(self.predictions) max_thr = max(self.predictions) all_thrs = np.linspace(min_thr,max_thr,n_samples).tolist() for t in all_thrs: acc = self.compute_acc(self.predictions,self.gt,t) acc_thrs.append((t,acc)) acc_thrs_arr = np.array(acc_thrs) idx_max = acc_thrs_arr[:,1].argmax() best_thr = acc_thrs_arr[idx_max,0] self.best_acc = acc_thrs_arr[idx_max,1] return best_thr, self.best_acc def compute_acc(self,list_scores,list_labels,thr): labels = np.array(list_labels) scores_th = (np.array(list_scores) >= thr).astype(np.int32) acc = (scores_th==labels).sum()/labels.size return acc def get_precision(self,criterion,thr): '''compute the best precision''' pred_labels = [] for d in self.predictions: if (d < thr): pred_labels.append(0) elif (d >= thr): pred_labels.append(1) self.ap_0 = metrics.precision_score(self.gt, pred_labels, average='binary', pos_label=0) self.ap_1 = metrics.precision_score(self.gt, pred_labels, average='binary', pos_label=1) self.weighted_ap = metrics.precision_score(self.gt, pred_labels, average='weighted') class Metrics(object): def __init__(self): self.tp = 0 self.tot_samples = 0 self.loss = 0.0 self.loss_samples = 0 self.roc = ROC() self.best_valid_acc = 0.0 self.best_valid_thr = 0.0 self.tuned_acc_thrs = (0,0) def update(self,tp,loss_value,samples): self.tp+=tp self.tot_samples+=samples self.loss+=loss_value self.loss_samples+=1 def get_avg_loss(self): if self.loss_samples == 0: raise ZeroDivisionError('not enough sample to avg loss') return self.loss/self.loss_samples def count_matching_samples(preds,true_labels,criterion,use_magic_loss=True): acc = 0 if use_magic_loss: for l,d in zip(true_labels,preds): if (l == criterion.class_label and d < criterion.R) \ or (l != criterion.class_label and d >= criterion.R): acc += 1 else: matching_idx = (preds.argmax(dim=1)==true_labels) acc = matching_idx.sum().item() return acc def eval_model(model,dataset_name,valid_joined_generator,criterion, device,desc='valid',val_metrics=None, debug_mode=False): model.eval() print(f"with the eval model and the debug mode {debug_mode}.") with torch.no_grad(): metrics = Metrics() for jb, val_batch in tqdm(enumerate(valid_joined_generator,1), total=len(valid_joined_generator), desc=desc): if jb % 8 != 0 and debug_mode: continue ## Getting Input val_img_batch_mmodal, val_true_labels, image_names = val_batch n_samples = val_img_batch_mmodal.shape[0] val_img_batch_mmodal = val_img_batch_mmodal.float().to(device) val_true_labels = val_true_labels.long().to(device) ## Inference val_preds = model(val_img_batch_mmodal) ## Computing loss val_loss = criterion(val_preds, val_true_labels) log_probs = F.softmax(val_preds, dim=-1) res_probs = torch.argmax(log_probs, dim=-1) fixed_labels = 1 - val_true_labels ## acc/matching_samples. matching_num = count_matching_samples(val_preds,val_true_labels,criterion,use_magic_loss=False) # metrics.roc.predictions.extend(res_probs.tolist()) metrics.roc.predictions.extend(log_probs[:,0].tolist()) ## Inverting the labels metrics.roc.gt.extend(fixed_labels[:].tolist()) metrics.update(matching_num,val_loss.item(),n_samples) ## Getting the Results metrics.roc.eval() print("the auc is: %.5f"%metrics.roc.auc) best_acc = best_thr = None best_thr, best_acc = metrics.roc.compute_best_accuracy() metrics.best_valid_acc = best_acc metrics.best_valid_thr = best_thr print("the accuracy is: %.5f: "%best_acc) print("the threshold is: %.5f: "%best_thr) fpr_values = [0.1,0.01] for fpr_value in fpr_values: tpr_fpr, score_for_tpr_fpr = metrics.roc.get_tpr_at_fpr(fpr_value) print('tpr_fpr_%.1f: '%(fpr_value*100.0), "%.5f"%tpr_fpr) ## Setting the model back to train mode model.train() return metrics def display_eval_tb(writer,metrics,tot_iter,desc='test',old_metrics=False): avg_loss = metrics.get_avg_loss() acc = metrics.roc.best_acc auc = metrics.roc.auc writer.add_scalar('%s/loss'%desc, avg_loss, tot_iter) writer.add_scalar('%s/acc'%desc, acc, tot_iter) writer.add_scalar('%s/auc'%desc, auc, tot_iter) fpr_values = [0.1,0.01] for fpr_value in fpr_values: tpr_fpr, score_for_tpr_fpr = metrics.roc.get_tpr_at_fpr(fpr_value) writer.add_scalar('%s/tpr_fpr_%.0f'%(desc,(fpr_value*100.0)), tpr_fpr, tot_iter) def train_logging(string, writer, logger, epoch, saver, tot_iter, loss, accu, lr_scheduler): _, hours, mins = saver.check_time() logger.info("[Epoch %d] | h:%d m:%d | iteration: %d, loss: %f, accu: %f", epoch, hours, mins, tot_iter, loss, accu) writer.add_scalar(string, loss, tot_iter ) for count, gp in enumerate(lr_scheduler.optimizer.param_groups,1): writer.add_scalar('progress/lr_%d'%count, gp['lr'], tot_iter) writer.add_scalar('progress/epoch', epoch, tot_iter) writer.add_scalar('progress/curr_patience',lr_scheduler.num_bad_epochs,tot_iter) writer.add_scalar('progress/patience',lr_scheduler.patience,tot_iter) class lrSched_monitor(object): """ This class is used to monitor the learning rate scheduler's behavior during training. If the learning rate decreases then this class re-initializes the last best state of the model and starts training from that point of time. Parameters ---------- model : torch model scheduler : learning rate scheduler object from training data_config : this object holds model_path and model_name, used to load the last best model. """ def __init__(self, model, scheduler, data_config): self.model = model self.scheduler = scheduler self.model_name = data_config.model_name self.model_path = data_config.model_path self._last_lr = [0]*len(scheduler.optimizer.param_groups) self.prev_lr_mean = self.get_lr_mean() ## Get the current mean learning rate from the optimizer def get_lr_mean(self): lr_mean = 0 for i, grp in enumerate(self.scheduler.optimizer.param_groups): if 'lr' in grp.keys(): lr_mean += grp['lr'] self._last_lr[i] = grp['lr'] return lr_mean/(i+1) ## This is the function that is to be called right after lr_scheduler.step(val_loss) def monitor(self): if self.scheduler.num_bad_epochs == self.scheduler.patience: self.prev_lr_mean = self.get_lr_mean() elif self.get_lr_mean() < self.prev_lr_mean: self.load_best_model() self.prev_lr_mean = self.get_lr_mean() ## This function loads the last best model once the learning rate decreases def load_best_model(self): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if torch.cuda.device_count() > 1: ckpt = torch.load(os.path.join(self.model_path,'best_model.pth')) self.model.load_state_dict(ckpt['model_state_dict'], strict=True) self.scheduler.optimizer.load_state_dict(ckpt['optimizer_state_dict']) else: print(f'Loading the best model from {self.model_path}') if device.type == 'cpu': ckpt = torch.load(os.path.join(self.model_path,'best_model.pth'), map_location='cpu') else: ckpt = torch.load(os.path.join(self.model_path,'best_model.pth')) ## Model State Dict state_dict = ckpt['model_state_dict'] ## Since the model files are saved on dataparallel we use the below hack to load the weights on a model in cpu or a model on single gpu. keys = state_dict.keys() values = state_dict.values() new_keys = [] for key in keys: new_key = key.replace('module.','') # remove the 'module.' new_keys.append(new_key) new_state_dict = OrderedDict(list(zip(new_keys, values))) # create a new OrderedDict with (key, value) pairs self.model.load_state_dict(new_state_dict, strict=True) ## Optimizer State Dict optim_state_dict = ckpt['optimizer_state_dict'] # Since the model files are saved on dataparallel we use the below hack to load the optimizer state in cpu or a model on single gpu. keys = optim_state_dict.keys() values = optim_state_dict.values() new_keys = [] for key in keys: new_key = key.replace('module.','') # remove the 'module.' new_keys.append(new_key) new_optim_state_dict = OrderedDict(list(zip(new_keys, values))) # create a new OrderedDict with (key, value) pairs self.scheduler.optimizer.load_state_dict(new_optim_state_dict) ## Reduce the learning rate for i, grp in enumerate(self.scheduler.optimizer.param_groups): grp['lr'] = self._last_lr[i] ================================================ FILE: applications/deepfake_detection/test.py ================================================ # coding: utf-8 # author: Hierarchical Fine-Grained Image Forgery Detection and Localization import os import numpy as np import subprocess import logging import sys import torch import torch.nn as nn import torch.nn.functional as F import argparse import datetime from tensorboardX import SummaryWriter from torch.optim.lr_scheduler import ReduceLROnPlateau source_path = os.path.join('./sequence') sys.path.append(source_path) from rnn_stratified_dataloader import get_dataloader from models.HiFiNet_deepfake import HiFiNet_deepfake from torch_utils import eval_model,display_eval_tb,train_logging,lrSched_monitor from runjobs_utils import init_logger,Saver,DataConfig,torch_load_model logger = init_logger(__name__) logger.setLevel(logging.INFO) starting_time = datetime.datetime.now() ## Deterministic training _seed_id = 100 torch.backends.cudnn.deterministic = True torch.manual_seed(_seed_id) datasets = ['original', 'Deepfakes', 'FaceSwap', 'NeuralTextures', 'Face2Face'] # datasets = ['original', 'Deepfakes'] manipulations_names = [n for c, n in enumerate(datasets) if n != 'original'] manipulations_dict = {n : c for c, n in enumerate(manipulations_names) } manipulations_dict['original'] = 255 for key, value in manipulations_dict.items(): print(key, value) ctype = 'c40' # Create the parser parser = argparse.ArgumentParser(description='Process some integers.') parser.add_argument('--batch_size', type=int, default=4, help='input batch size for training (default: 32)') parser.add_argument('--window_size', type=int, default=5, help='size of the sliding window (default: 5)') parser.add_argument('--dataset_name', type=str, default="FF++", help='size of the sliding window (default: 5)') parser.add_argument('--gpus', type=int, default=4, help='input batch size for training (default: 32)') parser.add_argument('--feat_dim', type=int, default=270, help='input dim to rnn. (default: 32)') parser.add_argument('--valid_epoch', type=int, default=2, help='val epoch') parser.add_argument('--display_step', type=int, default=50, help='display the loss value.') parser.add_argument('--learning_rate', type=float, default=1e-3, help='the used learning rate') # Parse the arguments args = parser.parse_args() ## Hyper-params ####################### hparams = { 'epochs': 50, 'batch_size': args.batch_size, 'basic_lr': args.learning_rate, 'fine_tune': True, 'use_laplacian': True, 'step_factor': 0.1, 'patience': 20, 'weight_decay': 1e-06, 'lr_gamma': 2.0, 'use_magic_loss': True, 'feat_dim': args.feat_dim, 'drop_rate': 0.2, 'skip_valid': False, 'rnn_type': 'LSTM', 'rnn_hidden_size': 256, 'num_rnn_layers': 1, 'rnn_drop_rate': 0.2, 'bidir': False, 'merge_mode': 'concat', 'perc_margin_1': 0.95, 'perc_margin_2': 0.95, 'soft_boundary': False, 'dist_p': 2, 'radius_param': 0.84, 'strat_sampling': True, 'normalize': True, 'window_size': args.window_size, 'hop': 1, 'valid_epoch': args.valid_epoch, 'display_step': args.display_step, 'use_sched_monitor': True } batch_size = hparams['batch_size'] basic_lr = hparams['basic_lr'] fine_tune = hparams['fine_tune'] use_laplacian = hparams['use_laplacian'] step_factor = hparams['step_factor'] patience = hparams['patience'] weight_decay = hparams['weight_decay'] lr_gamma = hparams['lr_gamma'] use_magic_loss = hparams['use_magic_loss'] feat_dim = hparams['feat_dim'] drop_rate = hparams['drop_rate'] rnn_type = hparams['rnn_type'] rnn_hidden_size = hparams['rnn_hidden_size'] num_rnn_layers = hparams['num_rnn_layers'] rnn_drop_rate = hparams['rnn_drop_rate'] bidir = hparams['bidir'] merge_mode = hparams['merge_mode'] perc_margin_1 = hparams['perc_margin_1'] perc_margin_2 = hparams['perc_margin_2'] dist_p = hparams['dist_p'] radius_param = hparams['radius_param'] strat_sampling = hparams['strat_sampling'] normalize = hparams['normalize'] window_size = hparams['window_size'] hop = hparams['hop'] soft_boundary = hparams['soft_boundary'] use_sched_monitor = hparams['use_sched_monitor'] ######################################## workers_per_gpu = 6 dataset_name = f"{args.dataset_name}" exp_name = f"exp_FF_c40_bs_{batch_size}_lr_{basic_lr}_ws_{window_size}" model_name = exp_name model_path = os.path.join(f'./{dataset_name}', model_name) print(f'Window_size: {args.window_size}; Dataset: {dataset_name}; Batch_Size: {batch_size}; LR: {basic_lr}.') print(f"the model path is: ", model_path) ## Data Generation img_path = "/user/guoxia11/cvlshare/cvl-guoxia11/FaceForensics_HiFiNet" balanced_minibatch_opt = True if dataset_name == 'FF++': train_generator, train_dataset = get_dataloader( img_path, datasets, ctype, manipulations_dict, window_size, hop, use_laplacian, normalize, strat_sampling, balanced_minibatch_opt, 'train', batch_size, workers=workers_per_gpu*args.gpus ) test_generator, test_dataset = get_dataloader( img_path, datasets, ctype, manipulations_dict, window_size, hop, use_laplacian, normalize, strat_sampling, False, 'test', batch_size, workers=workers_per_gpu*args.gpus ) del train_dataset del test_dataset elif dataset_name == "CelebDF": pass ## TODO: will be released in the near future. elif dataset_name == 'DFW': pass ## TODO: will be released in the near future. ## Model definition device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = HiFiNet_deepfake(use_laplacian=True, drop_rate=drop_rate, use_magic_loss=False, pretrained=True, rnn_drop_rate=rnn_drop_rate, feat_dim=feat_dim, rnn_hidden_size=rnn_hidden_size, num_rnn_layers=num_rnn_layers, bidir=bidir) model = model.to(device) model = torch.nn.DataParallel(model).cuda() ## Fine-tuning functions params_to_optimize = model.parameters() optimizer = torch.optim.Adam(params_to_optimize, lr=basic_lr, weight_decay=weight_decay) lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=step_factor, min_lr=1e-06, patience=patience, verbose=True) criterion = nn.CrossEntropyLoss() ## Re-loading the model in case epoch_init=epoch=ib=ib_off=before_train=0 # load_model_path = os.path.join(model_path,'best_model.pth') Not as good as the current_model.pth load_model_path = os.path.join(model_path,'current_model.pth') val_loss = np.inf if os.path.exists(load_model_path): logger.info(f'Loading weights, optimizer and scheduler from {load_model_path}...') ib_off, epoch_init, scheduler, val_loss = torch_load_model(model, optimizer, load_model_path) ## Saver object and data config data_config = DataConfig(model_path, model_name) sched_monitor = lrSched_monitor(model, lr_scheduler, data_config) ## Start testing metrics = eval_model(model,dataset_name,test_generator,criterion,device,desc='valid',val_metrics=None,debug_mode=False) ================================================ FILE: applications/deepfake_detection/test.sh ================================================ source ~/.bashrc conda activate HiFi_Net_deepfake CUDA_NUM="0,1,3,4,5,6,7" CUDA_VISIBLE_DEVICES=$CUDA_NUM python test.py \ --dataset_name FF++ \ --batch_size 32 \ --window_size 10 \ --gpus 7 \ --valid_epoch 1 \ --feat_dim 1000 \ --learning_rate 1e-4 \ --display_step 150 ================================================ FILE: applications/deepfake_detection/train.py ================================================ # coding: utf-8 # author: Hierarchical Fine-Grained Image Forgery Detection and Localization import os import numpy as np import subprocess import logging import sys import torch import torch.nn as nn import torch.nn.functional as F import argparse import datetime from tensorboardX import SummaryWriter from torch.optim.lr_scheduler import ReduceLROnPlateau source_path = os.path.join('./sequence') sys.path.append(source_path) from rnn_stratified_dataloader import get_dataloader from models.HiFiNet_deepfake import HiFiNet_deepfake from torch_utils import eval_model,display_eval_tb,train_logging,lrSched_monitor from runjobs_utils import init_logger,Saver,DataConfig,torch_load_model logger = init_logger(__name__) logger.setLevel(logging.INFO) starting_time = datetime.datetime.now() ## Deterministic training _seed_id = 100 torch.backends.cudnn.deterministic = True torch.manual_seed(_seed_id) datasets = ['original', 'Deepfakes', 'FaceSwap', 'NeuralTextures', 'Face2Face'] # datasets = ['original', 'Deepfakes'] manipulations_names = [n for c, n in enumerate(datasets) if n != 'original'] manipulations_dict = {n : c for c, n in enumerate(manipulations_names) } manipulations_dict['original'] = 255 for key, value in manipulations_dict.items(): print(key, value) ctype = 'c40' # Create the parser parser = argparse.ArgumentParser(description='Process some integers.') parser.add_argument('--batch_size', type=int, default=4, help='input batch size for training (default: 32)') parser.add_argument('--window_size', type=int, default=5, help='size of the sliding window (default: 5)') parser.add_argument('--dataset_name', type=str, default="FF++", help='size of the sliding window (default: 5)') parser.add_argument('--gpus', type=int, default=4, help='input batch size for training (default: 32)') parser.add_argument('--feat_dim', type=int, default=270, help='input dim to rnn. (default: 32)') parser.add_argument('--valid_epoch', type=int, default=2, help='val epoch') parser.add_argument('--display_step', type=int, default=50, help='display the loss value.') parser.add_argument('--learning_rate', type=float, default=1e-3, help='the used learning rate') # Parse the arguments args = parser.parse_args() ## Hyper-params ####################### hparams = { 'epochs': 50, 'batch_size': args.batch_size, 'basic_lr': args.learning_rate, 'fine_tune': True, 'use_laplacian': True, 'step_factor': 0.1, 'patience': 20, 'weight_decay': 1e-06, 'lr_gamma': 2.0, 'use_magic_loss': True, 'feat_dim': args.feat_dim, 'drop_rate': 0.2, 'skip_valid': False, 'rnn_type': 'LSTM', 'rnn_hidden_size': 256, 'num_rnn_layers': 1, 'rnn_drop_rate': 0.2, 'bidir': False, 'merge_mode': 'concat', 'perc_margin_1': 0.95, 'perc_margin_2': 0.95, 'soft_boundary': False, 'dist_p': 2, 'radius_param': 0.84, 'strat_sampling': True, 'normalize': True, 'window_size': args.window_size, 'hop': 1, 'valid_epoch': args.valid_epoch, 'display_step': args.display_step, 'use_sched_monitor': True } batch_size = hparams['batch_size'] basic_lr = hparams['basic_lr'] fine_tune = hparams['fine_tune'] use_laplacian = hparams['use_laplacian'] step_factor = hparams['step_factor'] patience = hparams['patience'] weight_decay = hparams['weight_decay'] lr_gamma = hparams['lr_gamma'] use_magic_loss = hparams['use_magic_loss'] feat_dim = hparams['feat_dim'] drop_rate = hparams['drop_rate'] rnn_type = hparams['rnn_type'] rnn_hidden_size = hparams['rnn_hidden_size'] num_rnn_layers = hparams['num_rnn_layers'] rnn_drop_rate = hparams['rnn_drop_rate'] bidir = hparams['bidir'] merge_mode = hparams['merge_mode'] perc_margin_1 = hparams['perc_margin_1'] perc_margin_2 = hparams['perc_margin_2'] dist_p = hparams['dist_p'] radius_param = hparams['radius_param'] strat_sampling = hparams['strat_sampling'] normalize = hparams['normalize'] window_size = hparams['window_size'] hop = hparams['hop'] soft_boundary = hparams['soft_boundary'] use_sched_monitor = hparams['use_sched_monitor'] ######################################## workers_per_gpu = 6 dataset_name = f"{args.dataset_name}" exp_name = f"exp_FF_c40_bs_{batch_size}_lr_{basic_lr}_ws_{window_size}" model_name = exp_name model_path = os.path.join(f'./{dataset_name}', model_name) print(f'Window_size: {args.window_size}; Dataset: {dataset_name}; Batch_Size: {batch_size}; LR: {basic_lr}.') os.makedirs('./log', exist_ok=True) log_file_path = f"log/{exp_name}.txt" with open(log_file_path, "a+") as log_file: log_file.write( f'Dataset Name: {dataset_name} \n' f'Window_size: {args.window_size}' ) # Create the model path if doesn't exists if not os.path.exists(model_path): subprocess.call(f"mkdir -p {model_path}", shell=True) ## Data Generation img_path = "/user/guoxia11/cvlshare/cvl-guoxia11/FaceForensics_HiFiNet" balanced_minibatch_opt = True if dataset_name == 'FF++': train_generator, train_dataset = get_dataloader( img_path, datasets, ctype, manipulations_dict, window_size, hop, use_laplacian, normalize, strat_sampling, balanced_minibatch_opt, 'train', batch_size, workers=workers_per_gpu*args.gpus ) test_generator, test_dataset = get_dataloader( img_path, datasets, ctype, manipulations_dict, window_size, hop, use_laplacian, normalize, strat_sampling, False, 'test', batch_size, workers=workers_per_gpu*args.gpus ) # print("the dataset length is: ", len(train_dataset)) # print("the dataloader length is: ", len(train_generator)) del train_dataset del test_dataset elif dataset_name == "CelebDF": pass ## TODO: will be released in the near future. elif dataset_name == 'DFW': pass ## TODO: will be released in the near future. ## Model definition device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = HiFiNet_deepfake(use_laplacian=True, drop_rate=drop_rate, use_magic_loss=False, pretrained=True, rnn_drop_rate=rnn_drop_rate, feat_dim=feat_dim, rnn_hidden_size=rnn_hidden_size, num_rnn_layers=num_rnn_layers, bidir=bidir) model = model.to(device) model = torch.nn.DataParallel(model).cuda() ## Fine-tuning functions params_to_optimize = model.parameters() optimizer = torch.optim.Adam(params_to_optimize, lr=basic_lr, weight_decay=weight_decay) lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=step_factor, min_lr=1e-09, patience=patience, verbose=True) criterion = nn.CrossEntropyLoss() ## Re-loading the model in case epoch_init=epoch=ib=ib_off=before_train=0 load_model_path = os.path.join(model_path,'current_model.pth') val_loss = np.inf if os.path.exists(load_model_path): logger.info(f'Loading weights, optimizer and scheduler from {load_model_path}...') _, _, _, _ = torch_load_model(model, optimizer, load_model_path) ## Saver object and data config data_config = DataConfig(model_path, model_name) saver = Saver(model, optimizer, lr_scheduler, data_config, starting_time, hours_limit=23, mins_limit=0) sched_monitor = lrSched_monitor(model, lr_scheduler, data_config) ## Writer summary for tb tb_folder = os.path.join(model_path, 'tb_logs',model_name) writer = SummaryWriter(tb_folder) log_string_config = ' '.join([k+':'+str(v) for k,v in hparams.items()]) writer.add_text('config : %s' % model_name, log_string_config, 0) if epoch_init == 0: model.zero_grad() ## Start training tot_iter = 0 total_loss = 0 total_accu = 0 for epoch in range(epoch_init,hparams['epochs']): logger.info(f'Epoch ############: {epoch}') for ib, (img_batch_mmodal, true_labels, manip_type) in enumerate(train_generator,1): img_batch = img_batch_mmodal.float().to(device) true_labels = true_labels.long().to(device) optimizer.zero_grad() pred_labels = model(img_batch) loss = criterion(pred_labels, true_labels) total_loss += loss.item() log_probs = F.softmax(pred_labels, dim=-1) res_probs = torch.argmax(log_probs, dim=-1) summation = torch.sum(res_probs == true_labels) accu = summation / img_batch.shape[0] total_accu += accu loss.backward() optimizer.step() tot_iter += 1 if tot_iter % hparams['display_step'] == 0: train_logging( 'loss/train_loss_iter', writer, logger, epoch, saver, tot_iter, total_loss/hparams['display_step'], total_accu/hparams['display_step'], lr_scheduler ) with open(log_file_path, "a+") as log_file: log_file.write( f"Epoch: {epoch}, Iteration: {tot_iter}, " f"Train Loss: {total_loss/hparams['display_step']:.4f}, " f"Accuracy: {total_accu/hparams['display_step']:.4f}\n" ) total_loss = 0 total_accu = 0 saver.save_model(epoch,tot_iter,sys.maxsize,before_train,force_saving=True) if (epoch % hparams['valid_epoch'] == 0) or (epoch == hparams['epochs']): metrics = eval_model(model,dataset_name,test_generator,criterion,device,desc='valid',val_metrics=None,debug_mode=False) # metrics = eval_model(model,dataset_name,test_generator,criterion,device,desc='valid',val_metrics=None,debug_mode=True) val_loss = metrics.get_avg_loss() saver.save_model(epoch,ib+ib_off,val_loss,before_train,best_only=True) # display_eval_tb(writer,metrics,epoch,desc='valid') display_eval_tb(writer,metrics,epoch,desc='test') lr_scheduler.step(val_loss) sched_monitor.monitor() for i, grp in enumerate(sched_monitor.scheduler.optimizer.param_groups): if 'lr' in grp.keys(): print("the first grp learning rate is: ", grp['lr']) break file_path = f"./{exp_name}.txt" os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, 'a') as f: f.write(f"AUC: {metrics.roc.auc}\n") f.write(f"Best Accuracy: {metrics.best_valid_acc} (Threshold: {metrics.best_valid_thr})\n") for fpr_value in [0.1, 0.01]: tpr_fpr, score_for_tpr_fpr = metrics.roc.get_tpr_at_fpr(fpr_value) f.write(f"TPR at FPR={fpr_value*100}%: {tpr_fpr} (Score: {score_for_tpr_fpr})\n") f.write(f"Average Loss: {metrics.get_avg_loss()}\n") f.write("#" * 100) ================================================ FILE: applications/deepfake_detection/train.sh ================================================ source ~/.bashrc conda activate HiFi_Net_deepfake CUDA_NUM=0,1,3,4,5,6 CUDA_VISIBLE_DEVICES=$CUDA_NUM python train.py \ --dataset_name FF++ \ --batch_size 32 \ --window_size 10 \ --gpus 6 \ --valid_epoch 1 \ --feat_dim 1000 \ --learning_rate 1e-4 \ --display_step 150 ================================================ FILE: data_dir/CASIA/CASIA1/fake.txt ================================================ Sp_D_CND_A_pla0005_pla0023_0281.jpg Sp_D_CND_A_sec0056_sec0015_0282.jpg Sp_D_CNN_A_ani0049_ani0084_0266.jpg ================================================ FILE: data_dir/CASIA/CASIA2/fake.txt ================================================ Tp_D_CND_M_N_ani00018_sec00096_00138.tif Tp_D_CND_M_N_art00076_art00077_10289.tif Tp_D_CND_M_N_art00077_art00076_10290.tif ================================================ FILE: data_dir/Coverage/fake.txt ================================================ 10t.tif 11t.tif 12t.tif 13t.tif 14t.tif 15t.tif 16t.tif 17t.tif 18t.tif 19t.tif 1t.tif ================================================ FILE: data_dir/IMD2020/fake.txt ================================================ 00010_fake_01.jpg ================================================ FILE: data_dir/NIST16/alllist.txt ================================================ probe/NC2016_0016.jpg mask/mani_NC2016_0940.png probe/NC2016_0128.jpg mask/mani_NC2016_3942.png probe/NC2016_0130.jpg mask/mani_NC2016_6409.png ================================================ FILE: data_dir/columbia/vallist.txt ================================================ canong3_canonxt_sub_01.tif canong3_canonxt_sub_02.tif canong3_canonxt_sub_03.tif canong3_canonxt_sub_04.tif canong3_canonxt_sub_05.tif canong3_canonxt_sub_06.tif canong3_canonxt_sub_07.tif canong3_canonxt_sub_08.tif canong3_canonxt_sub_09.tif ================================================ FILE: environment.yml ================================================ name: HiFi_Net channels: - conda-forge - pytorch - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu - absl-py=1.3.0=py37h06a4308_0 - aiohttp=3.8.3=py37h5eee18b_0 - aiosignal=1.2.0=pyhd3eb1b0_0 - async-timeout=4.0.2=py37h06a4308_0 - asynctest=0.13.0=py_0 - attrs=22.1.0=py37h06a4308_0 - blas=1.0=mkl - blinker=1.4=py37h06a4308_0 - brotlipy=0.7.0=py37h27cfd23_1003 - bzip2=1.0.8=h7b6447c_0 - c-ares=1.19.1=h5eee18b_0 - ca-certificates=2023.12.12=h06a4308_0 - cachetools=4.2.2=pyhd3eb1b0_0 - certifi=2022.12.7=py37h06a4308_0 - cffi=1.15.1=py37h5eee18b_3 - charset-normalizer=2.0.4=pyhd3eb1b0_0 - click=8.0.4=py37h06a4308_0 - cryptography=39.0.1=py37h9ce1e76_0 - cudatoolkit=11.3.1=h2bc3f7f_2 - cycler=0.11.0=pyhd3eb1b0_0 - ffmpeg=4.3=hf484d3e_0 - fftw=3.3.9=h27cfd23_1 - freetype=2.12.1=h4a9f257_0 - frozenlist=1.3.3=py37h5eee18b_0 - giflib=5.2.1=h5eee18b_3 - gmp=6.2.1=h295c915_3 - gnutls=3.6.15=he1e5248_0 - google-auth=2.6.0=pyhd3eb1b0_0 - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 - grpcio=1.42.0=py37hce63b2e_0 - icu=67.1=he1b5a44_0 - idna=3.4=py37h06a4308_0 - imageio=2.9.0=pyhd3eb1b0_0 - importlib-metadata=4.11.3=py37h06a4308_0 - intel-openmp=2021.4.0=h06a4308_3561 - joblib=1.1.0=pyhd3eb1b0_0 - jpeg=9e=h5eee18b_1 - kiwisolver=1.4.4=py37h6a678d5_0 - lame=3.100=h7b6447c_0 - lcms2=2.12=h3be6417_0 - ld_impl_linux-64=2.38=h1181459_1 - lerc=3.0=h295c915_0 - libblas=3.9.0=12_linux64_mkl - libcblas=3.9.0=12_linux64_mkl - libdeflate=1.17=h5eee18b_1 - libffi=3.4.4=h6a678d5_0 - libgcc-ng=11.2.0=h1234567_1 - libgfortran-ng=11.2.0=h00389a5_1 - libgfortran5=11.2.0=h1234567_1 - libgomp=11.2.0=h1234567_1 - libiconv=1.16=h7f8727e_2 - libidn2=2.3.4=h5eee18b_0 - libpng=1.6.39=h5eee18b_0 - libprotobuf=3.20.3=he621ea3_0 - libstdcxx-ng=11.2.0=h1234567_1 - libtasn1=4.19.0=h5eee18b_0 - libtiff=4.5.1=h6a678d5_0 - libunistring=0.9.10=h27cfd23_0 - libuv=1.44.2=h5eee18b_0 - libwebp=1.2.4=h11a3e52_1 - libwebp-base=1.2.4=h5eee18b_1 - lz4-c=1.9.4=h6a678d5_0 - markdown=3.4.1=py37h06a4308_0 - markupsafe=2.1.1=py37h7f8727e_0 - matplotlib=3.2.2=1 - matplotlib-base=3.2.2=py37h1d35a4c_1 - mkl=2021.4.0=h06a4308_640 - mkl-service=2.4.0=py37h7f8727e_0 - mkl_fft=1.3.1=py37hd3c417c_0 - mkl_random=1.2.2=py37h51133e4_0 - multidict=6.0.2=py37h5eee18b_0 - ncurses=6.4=h6a678d5_0 - nettle=3.7.3=hbbd107a_1 - numpy=1.21.5=py37h6c91a56_3 - numpy-base=1.21.5=py37ha15fc14_3 - oauthlib=3.2.1=py37h06a4308_0 - openh264=2.1.1=h4ff587b_0 - openssl=1.1.1w=h7f8727e_0 - pillow=9.4.0=py37h6a678d5_0 - pip=23.3.2=pyhd8ed1ab_0 - protobuf=3.20.3=py37h6a678d5_0 - pyasn1=0.4.8=pyhd3eb1b0_0 - pyasn1-modules=0.2.8=py_0 - pycparser=2.21=pyhd3eb1b0_0 - pyjwt=2.4.0=py37h06a4308_0 - pyopenssl=23.0.0=py37h06a4308_0 - pyparsing=3.0.9=py37h06a4308_0 - pysocks=1.7.1=py37_1 - python=3.7.16=h7a1cb2a_0 - python-dateutil=2.8.2=pyhd3eb1b0_0 - python_abi=3.7=2_cp37m - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0 - pytorch-mutex=1.0=cuda - pyyaml=6.0=py37h5eee18b_1 - readline=8.2=h5eee18b_0 - requests=2.28.1=py37h06a4308_0 - requests-oauthlib=1.3.0=py_0 - rsa=4.7.2=pyhd3eb1b0_1 - scikit-learn=1.0.2=py37hf9e9bfc_0 - scipy=1.7.3=py37h6c91a56_2 - setuptools=68.2.2=pyhd8ed1ab_0 - six=1.16.0=pyhd3eb1b0_1 - sqlite=3.41.2=h5eee18b_0 - tensorboard=2.10.0=py37h06a4308_0 - tensorboard-data-server=0.6.1=py37h52d8a92_0 - tensorboard-plugin-wit=1.8.1=py37h06a4308_0 - threadpoolctl=2.2.0=pyh0d69192_0 - tk=8.6.12=h1ccaba5_0 - torchvision=0.12.0=py37_cu113 - tornado=5.1.1=py37h7b6447c_0 - tqdm=4.64.1=py37h06a4308_0 - typing-extensions=4.3.0=py37h06a4308_0 - typing_extensions=4.3.0=py37h06a4308_0 - urllib3=1.26.14=py37h06a4308_0 - werkzeug=2.2.2=py37h06a4308_0 - wheel=0.38.4=py37h06a4308_0 - xz=5.4.5=h5eee18b_0 - yacs=0.1.6=pyhd3eb1b0_1 - yaml=0.2.5=h7b6447c_0 - yarl=1.8.1=py37h5eee18b_0 - zipp=3.11.0=py37h06a4308_0 - zlib=1.2.13=h5eee18b_0 - zstd=1.5.5=hc292b87_0 - pip: - einops==0.6.1 - kmeans-pytorch==0.3 - opencv-python==4.8.1.78 prefix: /home/aya/.conda/envs/HiFi_Net ================================================ FILE: models/GaussianSmoothing.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ import math import numbers import torch from torch import nn from torch.nn import functional as F class GaussianSmoothing(nn.Module): """ Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input using a depthwise convolution. Arguments: channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the gaussian kernel. dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial). """ def __init__(self, channels, kernel_size, sigma, dim=2): super(GaussianSmoothing, self).__init__() if isinstance(kernel_size, numbers.Number): kernel_size = [kernel_size] * dim if isinstance(sigma, numbers.Number): sigma = [sigma] * dim # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( [ torch.arange(size, dtype=torch.float32) for size in kernel_size ], indexing='ij' ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ torch.exp(-((mgrid - mean) / std) ** 2 / 2) # Make sure sum of values in gaussian kernel equals 1. kernel = kernel / torch.sum(kernel) # Reshape to depthwise convolutional weight kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) self.register_buffer('weight', kernel) self.groups = channels if dim == 1: self.conv = F.conv1d elif dim == 2: self.conv = F.conv2d elif dim == 3: self.conv = F.conv3d else: raise RuntimeError( 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) ) def forward(self, input): """ Apply gaussian filter to input. Arguments: input (torch.Tensor): Input to apply gaussian filter on. Returns: filtered (torch.Tensor): Filtered output. """ return self.conv(input, weight=self.weight, groups=self.groups) ================================================ FILE: models/LaPlacianMs.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ import torch import torch.nn as nn from torch.nn import functional as F from .GaussianSmoothing import GaussianSmoothing class LaPlacianMs(nn.Module): def __init__(self,in_c,gauss_ker_size=3,scale=[2],drop_rate=0.2): super(LaPlacianMs, self).__init__() self.scale = scale self.gauss_ker_size = gauss_ker_size ## apply gaussian smoothing to input feature maps with 3 planes ## with kernel size K and sigma s self.smoothing = nn.ModuleDict() for s in self.scale: self.smoothing['scale-'+str(s)] = GaussianSmoothing(in_c, self.gauss_ker_size, s) self.conv_1x1 = nn.Sequential(nn.Conv2d(in_c*len(scale), in_c, kernel_size=1, stride=1, bias=False,groups=1), nn.BatchNorm2d(in_c), nn.ReLU(inplace=True), nn.Dropout(p=drop_rate) ) # Official init from torch repo. for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.constant_(m.bias, 0) def down(self,x,s): return F.interpolate(x,scale_factor=s, mode='bilinear', align_corners=False) def up (self,x, size): return F.interpolate(x,size=size,mode='bilinear',align_corners=False) def forward(self, x): for i, s in enumerate(self.scale): sm = self.smoothing['scale-'+str(s)](x) sm = self.down(sm,1/s) sm = self.up(sm,(x.shape[2],x.shape[3])) if i == 0: diff = x - sm else: diff = torch.cat((diff, x - sm), dim=1) return self.conv_1x1(diff) ================================================ FILE: models/NLCDetection_api.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ import torch import torch.nn as nn import torch.nn.functional as F from models.seg_hrnet_config import get_cfg_defaults import time def weights_init(init_type='gaussian'): def init_fun(m): classname = m.__class__.__name__ if (classname.find('Conv') == 0 or classname.find( 'Linear') == 0) and hasattr(m, 'weight'): if init_type == 'gaussian': nn.init.normal_(m.weight, 0.0, 0.02) elif init_type == 'xavier': nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) elif init_type == 'kaiming': nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') elif init_type == 'orthogonal': nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) elif init_type == 'default': pass else: assert 0, "Unsupported initialization: {}".format(init_type) if hasattr(m, 'bias') and m.bias is not None: nn.init.constant_(m.bias, 0.0) return init_fun class PartialConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super().__init__() self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, False) self.input_conv.apply(weights_init('kaiming')) torch.nn.init.constant_(self.mask_conv.weight, 1.0) # mask is not updated for param in self.mask_conv.parameters(): param.requires_grad = False def forward(self, input, mask): # http://masc.cs.gmu.edu/wiki/partialconv # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M) # W^T* (M .* X) / sum(M) + b = [C(M .* X) – C(0)] / D(M) + C(0) ## GX: masking the input outside function. output = self.input_conv(input) if self.input_conv.bias is not None: output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output) else: output_bias = torch.zeros_like(output) with torch.no_grad(): output_mask = self.mask_conv(mask) no_update_holes = output_mask == 0 ## in output_mask, fills the 0-value-position with 1.0 ## without this step, math error occurs. mask_sum = output_mask.masked_fill_(no_update_holes, 1.0) output_pre = (output - output_bias) / mask_sum + output_bias output = output_pre.masked_fill_(no_update_holes, 0.0) new_mask = torch.ones_like(output) new_mask = new_mask.masked_fill_(no_update_holes, 0.0) return output, new_mask class NonLocalMask(nn.Module): def __init__(self, in_channels, reduce_scale): super(NonLocalMask, self).__init__() self.r = reduce_scale # input channel number self.ic = in_channels * self.r * self.r # middle channel number self.mc = self.ic self.g = nn.Conv2d(in_channels=self.ic, out_channels=self.ic, kernel_size=1, stride=1, padding=0) self.theta = nn.Conv2d(in_channels=self.ic, out_channels=self.mc, kernel_size=1, stride=1, padding=0) self.phi = nn.Conv2d(in_channels=self.ic, out_channels=self.mc, kernel_size=1, stride=1, padding=0) self.W_s = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1, padding=0) self.gamma_s = nn.Parameter(torch.ones(1)) self.getmask = nn.Sequential( nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1) ) ## Pconv self.Pconv_1 = PartialConv(3, 3, kernel_size=3, stride=2) self.Pconv_2 = PartialConv(3, 3, kernel_size=3, stride=2) self.Pconv_3 = PartialConv(3, 1, kernel_size=3, stride=2) def forward(self, x, img): b, c, h, w = x.shape x1 = x.reshape(b, self.ic, h // self.r, w // self.r) # g x g_x = self.g(x1).view(b, self.ic, -1) g_x = g_x.permute(0, 2, 1) # theta theta_x = self.theta(x1).view(b, self.mc, -1) theta_x_s = theta_x.permute(0, 2, 1) # phi x phi_x = self.phi(x1).view(b, self.mc, -1) phi_x_s = phi_x # non-local attention f_s = torch.matmul(theta_x_s, phi_x_s) f_s_div = F.softmax(f_s, dim=-1) # get y_s y_s = torch.matmul(f_s_div, g_x) y_s = y_s.permute(0, 2, 1).contiguous() y_s = y_s.view(b, c, h, w) # GX: (256,256,18), output mask for the deep metric loss. mask_feat = x + self.gamma_s * self.W_s(y_s) # get 1-dimensional mask_tmp mask_binary = torch.sigmoid(self.getmask(mask_feat)) mask_tmp = mask_binary.repeat(1, 3, 1, 1) mask_img = img * mask_tmp # mask_img is the overlaid image. ## conv output x, new_mask = self.Pconv_1(mask_img, mask_tmp) x, new_mask = self.Pconv_2(x, new_mask) x, _ = self.Pconv_3(x, new_mask) mask_binary = mask_binary.squeeze(dim=1) return x, mask_feat, mask_binary class Flatten(nn.Module): def __init__(self): super(Flatten, self).__init__() def forward(self, x): return x.view(x.size(0), -1) class Classifer(nn.Module): def __init__(self, in_channels, output_channels): super(Classifer, self).__init__() self.pool = nn.Sequential( # nn.AdaptiveAvgPool2d((1,1)), nn.AdaptiveAvgPool2d(1), Flatten() ) self.fc = nn.Linear(in_channels, output_channels, bias=True) self.relu = nn.ReLU(inplace=True) def forward(self, x): feat = self.pool(x) feat = self.relu(feat) cls_res = self.fc(feat) return cls_res class BranchCLS(nn.Module): def __init__(self, in_channels, output_channels): super(BranchCLS, self).__init__() self.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), Flatten() ) self.fc = nn.Linear(18, output_channels, bias=True) self.bn = nn.BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) self.branch_cls = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=32, padding=1, kernel_size=3, stride=1), nn.ReLU(inplace=True), nn.Conv2d(in_channels=32, out_channels=18, padding=1, kernel_size=3, stride=1), nn.ReLU(inplace=True), ) self.leakyrelu = nn.LeakyReLU(0.2) def forward(self, x): feat = self.branch_cls(x) x = self.pool(feat) x = self.bn(x) cls_res = self.fc(x) cls_pro = self.leakyrelu(cls_res) zero_vec = -9e15*torch.ones_like(cls_pro) cls_pro = torch.where(cls_pro > 0, cls_pro, zero_vec) return cls_res, cls_pro, feat class NLCDetection(nn.Module): def __init__(self): super(NLCDetection, self).__init__() self.split_tensor_1 = torch.tensor([1, 3]).cuda() self.split_tensor_2 = torch.tensor([1, 2, 1, 3]).cuda() self.softmax_m = nn.Softmax(dim=1) FENet_cfg = get_cfg_defaults() feat1_num, feat2_num, feat3_num, feat4_num = FENet_cfg['STAGE4']['NUM_CHANNELS'] ## mask generation branch. self.getmask = NonLocalMask(feat1_num, 4) ## classification branch. self.branch_cls_level_1 = BranchCLS(271, 14) # 252 + 18 = 270 self.branch_cls_level_2 = BranchCLS(252, 7) # 144+72+36 = 252 self.branch_cls_level_3 = BranchCLS(216, 5) # 144+72 = 216 self.branch_cls_level_4 = BranchCLS(144, 3) # 144 def forward(self, feat, img): s1, s2, s3, s4 = feat pconv_feat, mask, mask_binary = self.getmask(s1, img) pconv_feat = pconv_feat.clone().detach() pconv_1 = F.interpolate(pconv_feat, size=s1.size()[2:], mode='bilinear', align_corners=True) ## forth branch. cls_4, pro_4, _ = self.branch_cls_level_4(s4) cls_prob_4 = self.softmax_m(pro_4) cls_prob_40 = torch.unsqueeze(cls_prob_4[:,0],1) cls_prob_41 = torch.unsqueeze(cls_prob_4[:,1],1) cls_prob_42 = torch.unsqueeze(cls_prob_4[:,2],1) cls_prob_mask_3 = torch.cat([cls_prob_40, cls_prob_41, cls_prob_41, cls_prob_42, cls_prob_42],axis=1) ## third branch s4F = F.interpolate(s4, size=s3.size()[2:], mode='bilinear', align_corners=True) s3_input = torch.cat([s4F, s3], axis=1) cls_3, pro_3, _ = self.branch_cls_level_3(s3_input) cls_prob_3 = self.softmax_m(pro_3) cls_3 = cls_3 + cls_3 * cls_prob_mask_3 cls_prob_30 = torch.unsqueeze(cls_prob_3[:,0],1) cls_prob_31 = torch.unsqueeze(cls_prob_3[:,1],1) cls_prob_32 = torch.unsqueeze(cls_prob_3[:,2],1) cls_prob_33 = torch.unsqueeze(cls_prob_3[:,3],1) cls_prob_34 = torch.unsqueeze(cls_prob_3[:,4],1) cls_prob_mask_2 = torch.cat([cls_prob_30, cls_prob_31, cls_prob_31, cls_prob_32, cls_prob_32, cls_prob_33, cls_prob_34],axis=1) ## second branch s3F = F.interpolate(s3_input, size=s2.size()[2:], mode='bilinear', align_corners=True) s2_input = torch.cat([s3F, s2], axis=1) cls_2, pro_2, _ = self.branch_cls_level_2(s2_input) cls_prob_2 = self.softmax_m(pro_2) cls_2 = cls_2 + cls_2 * cls_prob_mask_2 cls_prob_20 = torch.unsqueeze(cls_prob_2[:,0],1) cls_prob_21 = torch.unsqueeze(cls_prob_2[:,1],1) cls_prob_22 = torch.unsqueeze(cls_prob_2[:,2],1) cls_prob_23 = torch.unsqueeze(cls_prob_2[:,3],1) cls_prob_24 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_25 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_26 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_mask_1 = torch.cat([cls_prob_20, cls_prob_21, cls_prob_21, cls_prob_22, cls_prob_22, # 4 diffusion cls_prob_23, cls_prob_23, cls_prob_24, cls_prob_24, # 4 gan cls_prob_25, cls_prob_25, # faceshifter+stgan cls_prob_26, cls_prob_26, cls_prob_26], axis=1) # 3 editing s2F = F.interpolate(s2_input, size=s1.size()[2:], mode='bilinear', align_corners=True) s1_input = torch.cat([s2F, s1, pconv_1], axis=1) cls_1, pro_1, _ = self.branch_cls_level_1(s1_input) cls_1 = cls_1 + cls_1 * cls_prob_mask_1 return mask, mask_binary, cls_4, cls_3, cls_2, cls_1 ================================================ FILE: models/NLCDetection_loc.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ import torch import torch.nn as nn import torch.nn.functional as F from models.seg_hrnet_config import get_cfg_defaults import time def weights_init(init_type='gaussian'): def init_fun(m): classname = m.__class__.__name__ if (classname.find('Conv') == 0 or classname.find( 'Linear') == 0) and hasattr(m, 'weight'): if init_type == 'gaussian': nn.init.normal_(m.weight, 0.0, 0.02) elif init_type == 'xavier': nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) elif init_type == 'kaiming': nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') elif init_type == 'orthogonal': nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) elif init_type == 'default': pass else: assert 0, "Unsupported initialization: {}".format(init_type) if hasattr(m, 'bias') and m.bias is not None: nn.init.constant_(m.bias, 0.0) return init_fun class PartialConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super().__init__() self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, False) self.input_conv.apply(weights_init('kaiming')) torch.nn.init.constant_(self.mask_conv.weight, 1.0) # mask is not updated for param in self.mask_conv.parameters(): param.requires_grad = False def forward(self, input, mask): # http://masc.cs.gmu.edu/wiki/partialconv # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M) # W^T* (M .* X) / sum(M) + b = [C(M .* X) – C(0)] / D(M) + C(0) ## GX: masking the input outside function. output = self.input_conv(input) if self.input_conv.bias is not None: output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output) else: output_bias = torch.zeros_like(output) with torch.no_grad(): output_mask = self.mask_conv(mask) no_update_holes = output_mask == 0 ## in output_mask, fills the 0-value-position with 1.0 ## without this step, math error occurs. mask_sum = output_mask.masked_fill_(no_update_holes, 1.0) output_pre = (output - output_bias) / mask_sum + output_bias output = output_pre.masked_fill_(no_update_holes, 0.0) new_mask = torch.ones_like(output) new_mask = new_mask.masked_fill_(no_update_holes, 0.0) return output, new_mask class NonLocalMask(nn.Module): def __init__(self, in_channels, reduce_scale): super(NonLocalMask, self).__init__() self.r = reduce_scale # input channel number self.ic = in_channels * self.r * self.r # middle channel number self.mc = self.ic self.g = nn.Conv2d(in_channels=self.ic, out_channels=self.ic, kernel_size=1, stride=1, padding=0) self.theta = nn.Conv2d(in_channels=self.ic, out_channels=self.mc, kernel_size=1, stride=1, padding=0) self.phi = nn.Conv2d(in_channels=self.ic, out_channels=self.mc, kernel_size=1, stride=1, padding=0) self.W_s = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1, padding=0) self.gamma_s = nn.Parameter(torch.ones(1)) self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=18, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU() self.conv_2 = nn.Conv2d(in_channels=18, out_channels=1, kernel_size=3, stride=1, padding=1) ## Pconv self.Pconv_1 = PartialConv(3, 3, kernel_size=3, stride=2) self.Pconv_2 = PartialConv(3, 3, kernel_size=3, stride=2) self.Pconv_3 = PartialConv(3, 1, kernel_size=3, stride=2) def forward(self, x, img): b, c, h, w = x.shape x1 = x.reshape(b, self.ic, h // self.r, w // self.r) # g x g_x = self.g(x1).view(b, self.ic, -1) g_x = g_x.permute(0, 2, 1) # theta theta_x = self.theta(x1).view(b, self.mc, -1) theta_x_s = theta_x.permute(0, 2, 1) # phi x phi_x = self.phi(x1).view(b, self.mc, -1) phi_x_s = phi_x # non-local attention f_s = torch.matmul(theta_x_s, phi_x_s) f_s_div = F.softmax(f_s, dim=-1) # get y_s y_s = torch.matmul(f_s_div, g_x) y_s = y_s.permute(0, 2, 1).contiguous() y_s = y_s.view(b, c, h, w) # GX: (256,256,18), output mask for the deep metric loss. mask_feat = x + self.gamma_s * self.W_s(y_s) # get 1-dimensional mask_tmp # mask_binary = self.getmask(mask_feat) mask_feat = self.conv_1(mask_feat) mask_binary = mask_feat mask_binary = self.relu(mask_binary) # print("mask_feat: ", mask_feat.size()) # torch.Size([4, 18, 256, 256]) mask_binary = self.conv_2(mask_binary) # print("mask_binary: ", mask_binary.size()) # torch.Size([4, 1, 256, 256]) mask_binary = torch.sigmoid(mask_binary) mask_tmp = mask_binary.repeat(1, 3, 1, 1) mask_img = img * mask_tmp # mask_img is the overlaid image. ## conv output x, new_mask = self.Pconv_1(mask_img, mask_tmp) x, new_mask = self.Pconv_2(x, new_mask) x, _ = self.Pconv_3(x, new_mask) mask_binary = mask_binary.squeeze(dim=1) return x, torch.sigmoid(mask_feat), mask_binary class Flatten(nn.Module): def __init__(self): super(Flatten, self).__init__() def forward(self, x): return x.view(x.size(0), -1) class Classifer(nn.Module): def __init__(self, in_channels, output_channels): super(Classifer, self).__init__() self.pool = nn.Sequential( # nn.AdaptiveAvgPool2d((1,1)), nn.AdaptiveAvgPool2d(1), Flatten() ) self.fc = nn.Linear(in_channels, output_channels, bias=True) self.relu = nn.ReLU(inplace=True) def forward(self, x): feat = self.pool(x) feat = self.relu(feat) cls_res = self.fc(feat) return cls_res class BranchCLS(nn.Module): def __init__(self, in_channels, output_channels): super(BranchCLS, self).__init__() self.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), Flatten() ) self.fc = nn.Linear(18, output_channels, bias=True) self.bn = nn.BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) self.branch_cls = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=32, padding=1, kernel_size=3, stride=1), nn.ReLU(inplace=True), nn.Conv2d(in_channels=32, out_channels=18, padding=1, kernel_size=3, stride=1), nn.ReLU(inplace=True), ) self.leakyrelu = nn.LeakyReLU(0.2) def forward(self, x): feat = self.branch_cls(x) x = self.pool(feat) x = self.bn(x) cls_res = self.fc(x) cls_pro = self.leakyrelu(cls_res) zero_vec = -9e15*torch.ones_like(cls_pro) cls_pro = torch.where(cls_pro > 0, cls_pro, zero_vec) return cls_res, cls_pro, feat class FPN_loc(nn.Module): '''self-implementation Feature Pyramid Networks ''' def __init__(self, args, clip_dim=64, multi_feat=None): super(FPN_loc, self).__init__() ## obtain the dimensions. feat1_num, feat2_num, feat3_num, feat4_num = multi_feat self.smooth_s4 = nn.Sequential( nn.Conv2d(feat4_num, clip_dim, kernel_size=(1, 1), stride=(1, 1)), nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) self.smooth_s3 = nn.Sequential( nn.Conv2d(feat3_num, clip_dim, kernel_size=(1, 1), stride=(1, 1)), nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) self.smooth_s2 = nn.Sequential( nn.Conv2d(feat2_num, clip_dim, kernel_size=(1, 1), stride=(1, 1)), nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) self.smooth_s1 = nn.Sequential( nn.Conv2d(feat1_num, clip_dim, kernel_size=(1, 1), stride=(1, 1)), nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ## new branch. self.fpn1 = nn.Sequential( nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(clip_dim), nn.ReLU(), # nn.Upsample(scale_factor=2) ) self.fpn2 = nn.Sequential( nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(clip_dim), nn.ReLU(), nn.Upsample(scale_factor=2) ) self.fpn3 = nn.Sequential( nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(clip_dim), nn.ReLU(), nn.Upsample(scale_factor=2), ) self.fpn4 = nn.Sequential( nn.Conv2d(clip_dim, clip_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(clip_dim), nn.ReLU(), nn.Upsample(scale_factor=2), ) smooth_ops = [self.smooth_s4, self.smooth_s3, self.smooth_s2, self.smooth_s1] fpn_ops = [self.fpn4, self.fpn3, self.fpn2, self.fpn1] class NLCDetection(nn.Module): def __init__(self): super(NLCDetection, self).__init__() self.crop_size = (256, 256) self.split_tensor_1 = torch.tensor([1, 3]).cuda() self.split_tensor_2 = torch.tensor([1, 2, 1, 3]).cuda() self.softmax_m = nn.Softmax(dim=1) FENet_cfg = get_cfg_defaults() feat1_num, feat2_num, feat3_num, feat4_num = FENet_cfg['STAGE4']['NUM_CHANNELS'] ## mask generation branch. feat_dim = 64 # large clip_dim will ruin the space of Multi-branch-feature-extractor self.getmask = NonLocalMask(feat_dim, 4) self.FPN_LOC = FPN_loc(feat_dim, multi_feat=FENet_cfg['STAGE4']['NUM_CHANNELS']) ## classification branch. self.branch_cls_level_1 = BranchCLS(317, 14) # 252 + 64 = 316 self.branch_cls_level_2 = BranchCLS(252, 7) # 144+72+36 = 252 self.branch_cls_level_3 = BranchCLS(216, 5) # 144+72 = 216 self.branch_cls_level_4 = BranchCLS(144, 3) # 144 def feature_resize(self, feat): '''first obtain the mask via the progressive scheme.''' s1, s2, s3, s4 = feat s1 = F.interpolate(s1, size=self.crop_size, mode='bilinear', align_corners=True) s2 = F.interpolate(s2, size=[i // 2 for i in self.crop_size], mode='bilinear', align_corners=True) s3 = F.interpolate(s3, size=[i // 4 for i in self.crop_size], mode='bilinear', align_corners=True) s4 = F.interpolate(s4, size=[i // 8 for i in self.crop_size], mode='bilinear', align_corners=True) return s1, s2, s3, s4 def forward(self, feat, img): s1, s2, s3, s4 = self.feature_resize(feat) img = F.interpolate(img, size=self.crop_size, mode='bilinear', align_corners=True) feat_4 = self.FPN_LOC.smooth_s4(s4) feat_4 = self.FPN_LOC.fpn4(feat_4) feat_3 = self.FPN_LOC.smooth_s3(s3) feat_3 = self.FPN_LOC.fpn3(feat_3+feat_4) feat_2 = self.FPN_LOC.smooth_s2(s2) feat_2 = self.FPN_LOC.fpn2(feat_2+feat_3) feat_1 = self.FPN_LOC.smooth_s1(s1) s1 = self.FPN_LOC.fpn1(feat_1+feat_2) pconv_feat, mask, mask_binary = self.getmask(s1, img) pconv_feat = pconv_feat.clone().detach() pconv_1 = F.interpolate(pconv_feat, size=s1.size()[2:], mode='bilinear', align_corners=True) ## forth branch. cls_4, pro_4, _ = self.branch_cls_level_4(s4) cls_prob_4 = self.softmax_m(pro_4) cls_prob_40 = torch.unsqueeze(cls_prob_4[:,0],1) cls_prob_41 = torch.unsqueeze(cls_prob_4[:,1],1) cls_prob_42 = torch.unsqueeze(cls_prob_4[:,2],1) cls_prob_mask_3 = torch.cat([cls_prob_40, cls_prob_41, cls_prob_41, cls_prob_42, cls_prob_42],axis=1) ## third branch s4F = F.interpolate(s4, size=s3.size()[2:], mode='bilinear', align_corners=True) s3_input = torch.cat([s4F, s3], axis=1) cls_3, pro_3, _ = self.branch_cls_level_3(s3_input) cls_prob_3 = self.softmax_m(pro_3) cls_3 = cls_3 + cls_3 * cls_prob_mask_3 cls_prob_30 = torch.unsqueeze(cls_prob_3[:,0],1) cls_prob_31 = torch.unsqueeze(cls_prob_3[:,1],1) cls_prob_32 = torch.unsqueeze(cls_prob_3[:,2],1) cls_prob_33 = torch.unsqueeze(cls_prob_3[:,3],1) cls_prob_34 = torch.unsqueeze(cls_prob_3[:,4],1) cls_prob_mask_2 = torch.cat([cls_prob_30, cls_prob_31, cls_prob_31, cls_prob_32, cls_prob_32, cls_prob_33, cls_prob_34],axis=1) ## second branch s3F = F.interpolate(s3_input, size=s2.size()[2:], mode='bilinear', align_corners=True) s2_input = torch.cat([s3F, s2], axis=1) cls_2, pro_2, _ = self.branch_cls_level_2(s2_input) cls_prob_2 = self.softmax_m(pro_2) cls_2 = cls_2 + cls_2 * cls_prob_mask_2 cls_prob_20 = torch.unsqueeze(cls_prob_2[:,0],1) cls_prob_21 = torch.unsqueeze(cls_prob_2[:,1],1) cls_prob_22 = torch.unsqueeze(cls_prob_2[:,2],1) cls_prob_23 = torch.unsqueeze(cls_prob_2[:,3],1) cls_prob_24 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_25 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_26 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_mask_1 = torch.cat([cls_prob_20, cls_prob_21, cls_prob_21, cls_prob_22, cls_prob_22, # 4 diffusion cls_prob_23, cls_prob_23, cls_prob_24, cls_prob_24, # 4 gan cls_prob_25, cls_prob_25, # faceshifter+stgan cls_prob_26, cls_prob_26, cls_prob_26], axis=1) # 3 editing s2F = F.interpolate(s2_input, size=s1.size()[2:], mode='bilinear', align_corners=True) s1_input = torch.cat([s2F, s1, pconv_1], axis=1) cls_1, pro_1, _ = self.branch_cls_level_1(s1_input) cls_1 = cls_1 + cls_1 * cls_prob_mask_1 mask = mask.squeeze(dim=1) return mask, mask_binary, cls_4, cls_3, cls_2, cls_1 ================================================ FILE: models/NLCDetection_pconv.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ import torch import torch.nn as nn import torch.nn.functional as F from models.seg_hrnet_config import get_cfg_defaults import time def weights_init(init_type='gaussian'): def init_fun(m): classname = m.__class__.__name__ if (classname.find('Conv') == 0 or classname.find( 'Linear') == 0) and hasattr(m, 'weight'): if init_type == 'gaussian': nn.init.normal_(m.weight, 0.0, 0.02) elif init_type == 'xavier': nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) elif init_type == 'kaiming': nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') elif init_type == 'orthogonal': nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) elif init_type == 'default': pass else: assert 0, "Unsupported initialization: {}".format(init_type) if hasattr(m, 'bias') and m.bias is not None: nn.init.constant_(m.bias, 0.0) return init_fun class PartialConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super().__init__() self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, False) self.input_conv.apply(weights_init('kaiming')) torch.nn.init.constant_(self.mask_conv.weight, 1.0) # mask is not updated for param in self.mask_conv.parameters(): param.requires_grad = False def forward(self, input, mask): # http://masc.cs.gmu.edu/wiki/partialconv # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M) # W^T* (M .* X) / sum(M) + b = [C(M .* X) – C(0)] / D(M) + C(0) ## GX: masking the input outside function. output = self.input_conv(input) if self.input_conv.bias is not None: output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output) else: output_bias = torch.zeros_like(output) with torch.no_grad(): output_mask = self.mask_conv(mask) no_update_holes = output_mask == 0 ## in output_mask, fills the 0-value-position with 1.0 ## without this step, math error occurs. mask_sum = output_mask.masked_fill_(no_update_holes, 1.0) output_pre = (output - output_bias) / mask_sum + output_bias output = output_pre.masked_fill_(no_update_holes, 0.0) new_mask = torch.ones_like(output) new_mask = new_mask.masked_fill_(no_update_holes, 0.0) return output, new_mask class NonLocalMask(nn.Module): def __init__(self, in_channels, reduce_scale): super(NonLocalMask, self).__init__() self.r = reduce_scale # input channel number self.ic = in_channels * self.r * self.r # middle channel number self.mc = self.ic self.g = nn.Conv2d(in_channels=self.ic, out_channels=self.ic, kernel_size=1, stride=1, padding=0) self.theta = nn.Conv2d(in_channels=self.ic, out_channels=self.mc, kernel_size=1, stride=1, padding=0) self.phi = nn.Conv2d(in_channels=self.ic, out_channels=self.mc, kernel_size=1, stride=1, padding=0) self.W_s = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1, padding=0) self.gamma_s = nn.Parameter(torch.ones(1)) self.getmask = nn.Sequential( nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1) ) ## Pconv self.Pconv_1 = PartialConv(3, 3, kernel_size=3, stride=2) self.Pconv_2 = PartialConv(3, 3, kernel_size=3, stride=2) self.Pconv_3 = PartialConv(3, 1, kernel_size=3, stride=2) def forward(self, x, img): b, c, h, w = x.shape x1 = x.reshape(b, self.ic, h // self.r, w // self.r) # g x g_x = self.g(x1).view(b, self.ic, -1) g_x = g_x.permute(0, 2, 1) # theta theta_x = self.theta(x1).view(b, self.mc, -1) theta_x_s = theta_x.permute(0, 2, 1) # phi x phi_x = self.phi(x1).view(b, self.mc, -1) phi_x_s = phi_x # non-local attention f_s = torch.matmul(theta_x_s, phi_x_s) f_s_div = F.softmax(f_s, dim=-1) # get y_s y_s = torch.matmul(f_s_div, g_x) y_s = y_s.permute(0, 2, 1).contiguous() y_s = y_s.view(b, c, h, w) # GX: (256,256,18), output mask for the deep metric loss. mask_feat = x + self.gamma_s * self.W_s(y_s) # get 1-dimensional mask_tmp mask_binary = torch.sigmoid(self.getmask(mask_feat)) mask_tmp = mask_binary.repeat(1, 3, 1, 1) mask_img = img * mask_tmp # mask_img is the overlaid image. ## conv output x, new_mask = self.Pconv_1(mask_img, mask_tmp) x, new_mask = self.Pconv_2(x, new_mask) x, _ = self.Pconv_3(x, new_mask) mask_binary = mask_binary.squeeze(dim=1) return x, mask_feat, mask_binary class Flatten(nn.Module): def __init__(self): super(Flatten, self).__init__() def forward(self, x): return x.view(x.size(0), -1) class Classifer(nn.Module): def __init__(self, in_channels, output_channels): super(Classifer, self).__init__() self.pool = nn.Sequential( # nn.AdaptiveAvgPool2d((1,1)), nn.AdaptiveAvgPool2d(1), Flatten() ) self.fc = nn.Linear(in_channels, output_channels, bias=True) self.relu = nn.ReLU(inplace=True) def forward(self, x): feat = self.pool(x) feat = self.relu(feat) cls_res = self.fc(feat) return cls_res class BranchCLS(nn.Module): def __init__(self, in_channels, output_channels): super(BranchCLS, self).__init__() self.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), Flatten() ) self.fc = nn.Linear(18, output_channels, bias=True) self.bn = nn.BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) self.branch_cls = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=32, padding=1, kernel_size=3, stride=1), nn.ReLU(inplace=True), nn.Conv2d(in_channels=32, out_channels=18, padding=1, kernel_size=3, stride=1), nn.ReLU(inplace=True), ) self.leakyrelu = nn.LeakyReLU(0.2) def forward(self, x): feat = self.branch_cls(x) x = self.pool(feat) x = self.bn(x) cls_res = self.fc(x) cls_pro = self.leakyrelu(cls_res) zero_vec = -9e15*torch.ones_like(cls_pro) cls_pro = torch.where(cls_pro > 0, cls_pro, zero_vec) return cls_res, cls_pro, feat class NLCDetection(nn.Module): def __init__(self, args): super(NLCDetection, self).__init__() self.crop_size = args.crop_size self.split_tensor_1 = torch.tensor([1, 3]).cuda() self.split_tensor_2 = torch.tensor([1, 2, 1, 3]).cuda() self.softmax_m = nn.Softmax(dim=1) FENet_cfg = get_cfg_defaults() feat1_num, feat2_num, feat3_num, feat4_num = FENet_cfg['STAGE4']['NUM_CHANNELS'] ## mask generation branch. self.getmask = NonLocalMask(feat1_num, 4) ## classification branch. self.branch_cls_level_1 = BranchCLS(271, 14) # 252 + 18 = 270 self.branch_cls_level_2 = BranchCLS(252, 7) # 144+72+36 = 252 self.branch_cls_level_3 = BranchCLS(216, 5) # 144+72 = 216 self.branch_cls_level_4 = BranchCLS(144, 3) # 144 def forward(self, feat, img): s1, s2, s3, s4 = feat # mask_binary is intermediate result, to ignore. pconv_feat, mask, mask_binary = self.getmask(s1, img) pconv_feat = pconv_feat.clone().detach() pconv_1 = F.interpolate(pconv_feat, size=s1.size()[2:], mode='bilinear', align_corners=True) ## forth branch. cls_4, pro_4, _ = self.branch_cls_level_4(s4) cls_prob_4 = self.softmax_m(pro_4) cls_prob_40 = torch.unsqueeze(cls_prob_4[:,0],1) cls_prob_41 = torch.unsqueeze(cls_prob_4[:,1],1) cls_prob_42 = torch.unsqueeze(cls_prob_4[:,2],1) cls_prob_mask_3 = torch.cat([cls_prob_40, cls_prob_41, cls_prob_41, cls_prob_42, cls_prob_42],axis=1) ## third branch s4F = F.interpolate(s4, size=s3.size()[2:], mode='bilinear', align_corners=True) s3_input = torch.cat([s4F, s3], axis=1) cls_3, pro_3, _ = self.branch_cls_level_3(s3_input) cls_prob_3 = self.softmax_m(pro_3) cls_3 = cls_3 + cls_3 * cls_prob_mask_3 cls_prob_30 = torch.unsqueeze(cls_prob_3[:,0],1) cls_prob_31 = torch.unsqueeze(cls_prob_3[:,1],1) cls_prob_32 = torch.unsqueeze(cls_prob_3[:,2],1) cls_prob_33 = torch.unsqueeze(cls_prob_3[:,3],1) cls_prob_34 = torch.unsqueeze(cls_prob_3[:,4],1) cls_prob_mask_2 = torch.cat([cls_prob_30, cls_prob_31, cls_prob_31, cls_prob_32, cls_prob_32, cls_prob_33, cls_prob_34],axis=1) ## second branch s3F = F.interpolate(s3_input, size=s2.size()[2:], mode='bilinear', align_corners=True) s2_input = torch.cat([s3F, s2], axis=1) cls_2, pro_2, _ = self.branch_cls_level_2(s2_input) cls_prob_2 = self.softmax_m(pro_2) cls_2 = cls_2 + cls_2 * cls_prob_mask_2 cls_prob_20 = torch.unsqueeze(cls_prob_2[:,0],1) cls_prob_21 = torch.unsqueeze(cls_prob_2[:,1],1) cls_prob_22 = torch.unsqueeze(cls_prob_2[:,2],1) cls_prob_23 = torch.unsqueeze(cls_prob_2[:,3],1) cls_prob_24 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_25 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_26 = torch.unsqueeze(cls_prob_2[:,4],1) cls_prob_mask_1 = torch.cat([cls_prob_20, cls_prob_21, cls_prob_21, cls_prob_22, cls_prob_22, # 4 diffusion cls_prob_23, cls_prob_23, cls_prob_24, cls_prob_24, # 4 gan cls_prob_25, cls_prob_25, # faceshifter+stgan cls_prob_26, cls_prob_26, cls_prob_26], axis=1) # 3 editing s2F = F.interpolate(s2_input, size=s1.size()[2:], mode='bilinear', align_corners=True) s1_input = torch.cat([s2F, s1, pconv_1], axis=1) cls_1, pro_1, _ = self.branch_cls_level_1(s1_input) cls_1 = cls_1 + cls_1 * cls_prob_mask_1 return mask, mask_binary, cls_4, cls_3, cls_2, cls_1 ================================================ FILE: models/hrnet_w18_small_v2.pth ================================================ [File too large to display: 15.3 MB] ================================================ FILE: models/seg_hrnet.py ================================================ # ------------------------------------------------------------------------------ # Copyright (c) Microsoft # Licensed under the MIT License. # The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn) # ------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import division from __future__ import print_function from .LaPlacianMs import LaPlacianMs from .NLCDetection_pconv import weights_init import os import logging import functools import numpy as np import torch import torch.nn as nn import torch._utils import torch.nn.functional as F BN_MOMENTUM = 0.01 logger = logging.getLogger(__name__) # noise generation def srm_generation(image): """ :param image: N * C * H * W :return: noises """ # srm kernel 1 srm1 = np.zeros([5, 5]).astype('float32') srm1[1:-1, 1:-1] = np.array([[-1, 2, -1], [2, -4, 2], [-1, 2, -1]]) srm1 /= 4. # srm kernel 2 srm2 = np.array([[-1, 2, -2, 2, -1], [2, -6, 8, -6, 2], [-2, 8, -12, 8, -2], [2, -6, 8, -6, 2], [-1, 2, -2, 2, -1]]).astype('float32') srm2 /= 12. # srm kernel 3 srm3 = np.zeros([5, 5]).astype('float32') srm3[2, 1:-1] = np.array([1, -2, 1]) srm3 /= 2. srm = np.stack([srm1, srm2, srm3], axis=0) W_srm = np.zeros([3, 3, 5, 5]).astype('float32') for i in range(3): W_srm[i, 0, :, :] = srm[i, :, :] W_srm[i, 1, :, :] = srm[i, :, :] W_srm[i, 2, :, :] = srm[i, :, :] W_srm = torch.from_numpy(W_srm).to(image.get_device()) srm_noise = F.conv2d(image, W_srm, padding=2) return srm_noise # bayar constrained layer class BayarConstraint(object): def __init__(self): pass def __call__(self, module): if hasattr(module, 'weight'): weight = module.weight.data # oc, ic, h, w h, w = weight.size()[2:] mask = torch.zeros_like(weight) mask[:, :, h//2, w//2] = 1 weight *= (1 - mask) rest_sum = torch.sum(weight, dim=(2, 3), keepdim=True) weight /= (rest_sum + 1e-7) weight -= mask module.weight.data = weight def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class CatDepth(nn.Module): def __init__(self): super(CatDepth, self).__init__() def forward(self, x, y): return torch.cat([x,y],dim=1) '''GX: basicblock contains two conv3x3 and two batch norm''' '''GX: at last, it has a residual connection''' class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=False) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu(out) return out '''GX: 3 conv + 3 bn then a residual.''' class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=False) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu(out) return out '''GX: the basic component in the network.''' class HighResolutionModule(nn.Module): def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True): super(HighResolutionModule, self).__init__() self._check_branches( num_branches, blocks, num_blocks, num_inchannels, num_channels) self.num_inchannels = num_inchannels self.fuse_method = fuse_method self.num_branches = num_branches self.multi_scale_output = multi_scale_output self.branches = self._make_branches( num_branches, blocks, num_blocks, num_channels) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(inplace=False) def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): if num_branches != len(num_blocks): error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( num_branches, len(num_blocks)) raise ValueError(error_msg) if num_branches != len(num_channels): error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( num_branches, len(num_channels)) raise ValueError(error_msg) if num_branches != len(num_inchannels): error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( num_branches, len(num_inchannels)) raise ValueError(error_msg) def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None if stride != 1 or \ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), ) layers = [] layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)) self.num_inchannels[branch_index] = \ num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) return nn.Sequential(*layers) def _make_branches(self, num_branches, block, num_blocks, num_channels): branches = [] for i in range(num_branches): branches.append( self._make_one_branch(i, block, num_blocks, num_channels)) return nn.ModuleList(branches) ## GX: fuse layer converts feature maps at different resolution branches ## GX: into the feature map of the new branches' feature map. ## GX: https://zhuanlan.zhihu.com/p/335333233 def _make_fuse_layers(self): if self.num_branches == 1: return None num_branches = self.num_branches num_inchannels = self.num_inchannels fuse_layers = [] for i in range(num_branches if self.multi_scale_output else 1): fuse_layer = [] for j in range(num_branches): if j > i: fuse_layer.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) elif j == i: fuse_layer.append(None) else: conv3x3s = [] for k in range(i - j): if k == i - j - 1: num_outchannels_conv3x3 = num_inchannels[i] conv3x3s.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM))) else: num_outchannels_conv3x3 = num_inchannels[j] conv3x3s.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM), nn.ReLU(inplace=False))) fuse_layer.append(nn.Sequential(*conv3x3s)) fuse_layers.append(nn.ModuleList(fuse_layer)) return nn.ModuleList(fuse_layers) def get_num_inchannels(self): return self.num_inchannels def forward(self, x): if self.num_branches == 1: return [self.branches[0](x[0])] for i in range(self.num_branches): x[i] = self.branches[i](x[i]) x_fuse = [] for i in range(len(self.fuse_layers)): y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) for j in range(1, self.num_branches): if i == j: y = y + x[j] elif j > i: width_output = x[i].shape[-1] height_output = x[i].shape[-2] y = y + F.interpolate( self.fuse_layers[i][j](x[j]), size=[height_output, width_output], mode='bilinear', align_corners=True) else: y = y + self.fuse_layers[i][j](x[j]) x_fuse.append(self.relu(y)) return x_fuse blocks_dict = { 'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck } ## GX: the HighResolutionNet has 4 stages. ## GX: each stage has one module which is HighResolutionModule. ## GX: HighResolutionModule has 1,2,3,4 branches. ## GX: each stage has a transitional layers in between. class HighResolutionNet(nn.Module): def __init__(self, config, **kwargs): super(HighResolutionNet, self).__init__() # noise conv # self.im_conv = nn.Conv2d(3, 10, kernel_size=3, stride=1, padding=1, bias=False) # self.bayar_conv = nn.Conv2d(3, 3, kernel_size=5, stride=1, padding=2, bias=False) # self.constraints = BayarConstraint() # stem net self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=False) # frequency branch self.conv1fre = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.conv2fre = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn2fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.laplacian = LaPlacianMs(in_c=64,gauss_ker_size=3,scale=[2,4,8]) # concat self.concat_depth = CatDepth() self.conv_1x1_merge = nn.Sequential(nn.Conv2d(128, 64, kernel_size=1, stride=1, bias=False,groups=2), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Dropout(p=0.2) ) # self.module_initializer = module_initializer() # self.conv_1x1_merge = self.module_initializer(self.conv_1x1_merge) self.conv_1x1_merge.apply(weights_init('kaiming')) self.stage1_cfg = config['STAGE1'] num_channels = self.stage1_cfg['NUM_CHANNELS'][0] block = blocks_dict[self.stage1_cfg['BLOCK']] num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) stage1_out_channel = block.expansion * num_channels self.stage2_cfg = config['STAGE2'] num_channels = self.stage2_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage2_cfg['BLOCK']] num_channels = [ num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition1 = self._make_transition_layer( [stage1_out_channel], num_channels) self.stage2, pre_stage_channels = self._make_stage( self.stage2_cfg, num_channels) self.stage3_cfg = config['STAGE3'] num_channels = self.stage3_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage3_cfg['BLOCK']] num_channels = [ num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition2 = self._make_transition_layer( pre_stage_channels, num_channels) self.stage3, pre_stage_channels = self._make_stage( self.stage3_cfg, num_channels) self.stage4_cfg = config['STAGE4'] num_channels = self.stage4_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage4_cfg['BLOCK']] num_channels = [ num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition3 = self._make_transition_layer( pre_stage_channels, num_channels) self.stage4, pre_stage_channels = self._make_stage( self.stage4_cfg, num_channels, multi_scale_output=True) last_inp_channels = np.int(np.sum(pre_stage_channels)) ## GX: one dimension matrix converts pre to pos. ## GX: if channel numbers are equal, pass it directly. ## GX: if channel numbers are different, using conv 3x3. ## GX: https://zhuanlan.zhihu.com/p/335333233 def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) transition_layers = [] for i in range(num_branches_cur): if i < num_branches_pre: if num_channels_cur_layer[i] != num_channels_pre_layer[i]: transition_layers.append(nn.Sequential( nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), nn.BatchNorm2d( num_channels_cur_layer[i], momentum=BN_MOMENTUM), nn.ReLU(inplace=False))) else: transition_layers.append(None) else: conv3x3s = [] for j in range(i + 1 - num_branches_pre): inchannels = num_channels_pre_layer[-1] outchannels = num_channels_cur_layer[i] \ if j == i - num_branches_pre else inchannels conv3x3s.append(nn.Sequential( nn.Conv2d( inchannels, outchannels, 3, 2, 1, bias=False), nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), nn.ReLU(inplace=False))) transition_layers.append(nn.Sequential(*conv3x3s)) return nn.ModuleList(transition_layers) ## GX: _make_layer creates a conv + bn def _make_layer(self, block, inplanes, planes, blocks, stride=1): downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), ) layers = [] layers.append(block(inplanes, planes, stride, downsample)) inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(inplanes, planes)) return nn.Sequential(*layers) def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): ## GX: num_modules are all 1 in this work. ## GX: light-weight architectures: num_blocks are all 0. ## GX: branch numbers are 2, 3, 4. num_modules = layer_config['NUM_MODULES'] num_branches = layer_config['NUM_BRANCHES'] num_blocks = layer_config['NUM_BLOCKS'] num_channels = layer_config['NUM_CHANNELS'] block = blocks_dict[layer_config['BLOCK']] fuse_method = layer_config['FUSE_METHOD'] modules = [] for i in range(num_modules): # multi_scale_output is only used last module if not multi_scale_output and i == num_modules - 1: reset_multi_scale_output = False else: reset_multi_scale_output = True modules.append( HighResolutionModule(num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output) ) num_inchannels = modules[-1].get_num_inchannels() return nn.Sequential(*modules), num_inchannels def forward(self, x): x_fre = self.conv1fre(x) x_fre = self.bn1fre(x_fre) x_fre = self.relu(x_fre) x_fre = self.laplacian(x_fre) x_fre = self.conv2fre(x_fre) x_fre = self.bn2fre(x_fre) x_fre = self.relu(x_fre) x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = self.concat_depth(x, x_fre) x = self.conv_1x1_merge(x) x = self.layer1(x) x_list = [] for i in range(self.stage2_cfg['NUM_BRANCHES']): if self.transition1[i] is not None: x_list.append(self.transition1[i](x)) else: x_list.append(x) y_list = self.stage2(x_list) x_list = [] for i in range(self.stage3_cfg['NUM_BRANCHES']): if self.transition2[i] is not None: if i < self.stage2_cfg['NUM_BRANCHES']: x_list.append(self.transition2[i](y_list[i])) else: x_list.append(self.transition2[i](y_list[-1])) else: x_list.append(y_list[i]) y_list = self.stage3(x_list) x_list = [] for i in range(self.stage4_cfg['NUM_BRANCHES']): if self.transition3[i] is not None: if i < self.stage3_cfg['NUM_BRANCHES']: x_list.append(self.transition3[i](y_list[i])) else: x_list.append(self.transition3[i](y_list[-1])) else: x_list.append(y_list[i]) x = self.stage4(x_list) return x def init_weights(self, pretrained='',): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, std=0.001) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) if os.path.isfile(pretrained): pretrained_dict = torch.load(pretrained) ## GX: official pre-trained dict. print('=> loading HRNet pretrained model {}'.format(pretrained)) model_dict = self.state_dict() model_pretrained_lst, model_nopretrained_lst = [], [] ## GX: model_dict is weights from the current architecture. pretrained_dict_used = {} ## GX: gather weights from pretrained_dict to model_dict. nopretrained_dict = {k: v for k, v in model_dict.items()} for k, v in model_dict.items(): pretrained_key = 'model.' + k if pretrained_key not in pretrained_dict.keys(): if 'stage2' in pretrained_key and 'fuse_layers' not in pretrained_key: if 'branches.2' in pretrained_key: pretrained_key = pretrained_key.replace('stage2.0.', 'stage3.0.') elif 'branches.3' in pretrained_key: pretrained_key = pretrained_key.replace('stage2.0.', 'stage4.0.') elif 'stage3' in pretrained_key and 'fuse_layers' not in pretrained_key: pretrained_key = pretrained_key.replace('stage3.0.', 'stage4.0.') elif 'fre' in pretrained_key: pretrained_key = pretrained_key.replace('fre', '') if pretrained_key in pretrained_dict.keys(): pretrained_dict_used[k] = pretrained_dict[pretrained_key] nopretrained_dict.pop(k) print("no pretrain dict length is: ", len(nopretrained_dict)) print("pretrained dict length is: ", len(pretrained_dict)) model_dict.update(pretrained_dict_used) self.load_state_dict(model_dict) def get_seg_model(cfg, **kwargs): model = HighResolutionNet(cfg, **kwargs) model.init_weights(cfg.PRETRAINED) return model ================================================ FILE: models/seg_hrnet_config.py ================================================ # ------------------------------------------------------------------------------ # Copyright (c) Microsoft # Licensed under the MIT License. # The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn) # ------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import division from __future__ import print_function from yacs.config import CfgNode as CN # high_resoluton_net related params for segmentation HRNET = CN() HRNET.PRETRAINED_LAYERS = ['*'] HRNET.STEM_INPLANES = 64 HRNET.FINAL_CONV_KERNEL = 1 HRNET.PRETRAINED = 'models/hrnet_w18_small_v2.pth' HRNET.STAGE1 = CN() HRNET.STAGE1.NUM_MODULES = 1 HRNET.STAGE1.NUM_BRANCHES = 1 HRNET.STAGE1.NUM_BLOCKS = [2] HRNET.STAGE1.NUM_CHANNELS = [64] HRNET.STAGE1.BLOCK = 'BOTTLENECK' HRNET.STAGE1.FUSE_METHOD = 'SUM' HRNET.STAGE2 = CN() HRNET.STAGE2.NUM_MODULES = 1 HRNET.STAGE2.NUM_BRANCHES = 4 HRNET.STAGE2.NUM_BLOCKS = [2, 2, 2, 2] HRNET.STAGE2.NUM_CHANNELS = [18, 36, 72, 144] HRNET.STAGE2.BLOCK = 'BASIC' HRNET.STAGE2.FUSE_METHOD = 'SUM' HRNET.STAGE3 = CN() HRNET.STAGE3.NUM_MODULES = 1 HRNET.STAGE3.NUM_BRANCHES = 4 HRNET.STAGE3.NUM_BLOCKS = [2, 2, 2, 2] HRNET.STAGE3.NUM_CHANNELS = [18, 36, 72, 144] HRNET.STAGE3.BLOCK = 'BASIC' HRNET.STAGE3.FUSE_METHOD = 'SUM' HRNET.STAGE4 = CN() HRNET.STAGE4.NUM_MODULES = 1 HRNET.STAGE4.NUM_BRANCHES = 4 HRNET.STAGE4.NUM_BLOCKS = [2, 2, 2, 2] HRNET.STAGE4.NUM_CHANNELS = [18, 36, 72, 144] HRNET.STAGE4.BLOCK = 'BASIC' HRNET.STAGE4.FUSE_METHOD = 'SUM' def get_cfg_defaults(): """Get a yacs CfgNode object with default values for my_project.""" # Return a clone so that the defaults will not be altered # This is for the "local variable" use pattern return HRNET.clone() ================================================ FILE: utils/custom_loss.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ import os import time import torch import torch.nn as nn from tqdm import tqdm device = torch.device('cuda:0') device_ids = [0] class IsolatingLossFunction(torch.nn.Module): def __init__(self, c, R, p=2, threshold_val=1.85): super().__init__() self.c = c.clone().detach() # Center of the hypershpere, c ∈ ℝ^d (d-dimensional real-valued vector) self.R = R.clone().detach() # Radius of the hypersphere, R ∈ ℝ^1 (Real-valued) self.p = p # norm value (p-norm), p ∈ ℝ^1 (Default 2) self.margin_natu = (0.15)*self.R self.margin_mani = (2.5)*self.R self.threshold = threshold_val*self.R print('\n') print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$') print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$') print(f'The Radius manipul is {self.margin_natu}.') print(f'The Radius expansn is {self.margin_mani}.') print(f'The Radius threshold is {self.threshold}.') print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$') print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$') print('\n') self.pdist = torch.nn.PairwiseDistance(p=self.p) # Creating a Pairwise Distance object self.dis_curBatch = 0 self.dis = 0 def forward(self, model_output, label, threshold_new=None, update_flag=None): '''output the distance mask and compute the loss.''' bs, feat_dim, w, h = model_output.size() model_output = model_output.permute(0,2,3,1) model_output = torch.reshape(model_output, (-1, feat_dim)) dist = self.pdist(model_output, self.c) self.dist = dist pred_mask = torch.gt(self.dist, self.threshold).to(torch.float32) pred_mask = torch.reshape(pred_mask, (bs,w,h,1)).permute(0,3,1,2) self.dist = torch.reshape(self.dist, (bs,w,h,1)).permute(0,3,1,2) self.dis_curBatch = pred_mask.to(device).to(torch.float32) label = torch.reshape(label, (bs*w*h,1)) label_sum = label.sum().item() label_nat = torch.eq(label,0) label_mani = torch.eq(label,1) assert dist.size() == label_nat[:,0].size() assert dist.size() == label_mani[:,0].size() label_nat_sum = label_nat.sum().item() label_mani_sum = label_mani.sum().item() dist_nat = torch.masked_select(dist, label_nat[:,0]) dist_mani = torch.max(torch.tensor(0).to(device).float(), torch.sub(self.margin_mani, torch.masked_select(dist, label_mani[:,0])) ) loss_nat = dist_nat.sum()/label_nat_sum if label_nat_sum != 0 else \ torch.tensor(0).to(device).float() loss_mani = dist_mani.sum()/label_mani_sum if label_mani_sum != 0 else \ torch.tensor(0).to(device).float() loss_total = loss_nat + loss_mani return loss_total.to(device), loss_mani.to(device), loss_nat.to(device) def inference(self, model_output): '''output the distance for the final binary mask.''' bs, feat_dim, w, h = model_output.size() model_output = model_output.permute(0,2,3,1) model_output = torch.reshape(model_output, (-1, feat_dim)) dist = self.pdist(model_output, self.c) self.dist = dist pred_mask = torch.gt(self.dist, self.threshold).to(torch.float32) pred_mask = torch.reshape(pred_mask, (bs,w,h,1)).permute(0,3,1,2) self.dist = torch.reshape(self.dist, (bs,w,h,1)).permute(0,3,1,2) self.dis_curBatch = pred_mask.to(device).to(torch.float32) return self.dis_curBatch.squeeze(dim=1), self.dist.squeeze(dim=1) def center_radius_init(args, FENet, SegNet, train_data_loader, debug=True, center=None): '''the center is the mean-value of pixel features of the real pixels''' sample_num = 0 center = torch.zeros(18).to(device) FENet.eval() SegNet.eval() with torch.no_grad(): for batch_id, train_data in enumerate(tqdm(train_data_loader, desc="compute center")): image, masks, cls, fcls, scls, tcls = train_data if batch_id % 10 != 0: continue mask_cls = fcls.eq(0) image_selected = image[mask_cls,:] if image_selected.size()[0] == 0: continue else: sample_num += image_selected.size()[0] mask1 = masks[0].to(device) image_selected = image_selected.to(device) cls = cls.to(device) mask1_fea = FENet(image_selected) mask1_fea, _, _, _, _, _ = SegNet(mask1_fea, image_selected) mask1_fea = torch.mean(mask1_fea,(0,2,3)) center += mask1_fea center = center/sample_num pdist = torch.nn.PairwiseDistance(2) radius = torch.tensor(0, dtype=torch.float32).to(device) with torch.no_grad(): for batch_id, train_data in enumerate(tqdm(train_data_loader, desc="compute radius")): if batch_id % 10 != 0: continue image, masks, cls, fcls, scls, tcls = train_data mask1 = masks[0].to(device) image = image.to(device) fcls = fcls.to(device) mask_cls = fcls.eq(0) image_selected = image[mask_cls,:] if image_selected.size()[0] == 0: continue mask1_fea = FENet(image_selected) mask1_fea, _, _, _, _, _ = SegNet(mask1_fea, image_selected) bs, channel, h, w = mask1_fea.size() mask1_fea = mask1_fea.permute(0,2,3,1) mask1_fea = torch.reshape(mask1_fea, (bs*w*h, -1)) dist = pdist(mask1_fea, center) dist_max = torch.max(dist) if radius < dist_max: radius = dist_max return center, radius def load_center_radius(args, FENet, SegNet, train_data_loader, center_radius_dir='center'): '''loading the pre-computed center and radius.''' center_radius_path = os.path.join(center_radius_dir, 'radius_center.pth') if os.path.exists(center_radius_path): load_dict_center_radius = torch.load(center_radius_path) center = load_dict_center_radius['center'] radius = load_dict_center_radius['radius'] center, radius = center.to(device), radius.to(device) else: os.makedirs(center_radius_dir, exist_ok=True) center, radius = center_radius_init(args, FENet, SegNet, train_data_loader, debug=True) torch.save({'center': center, 'radius': radius}, center_radius_path) return center, radius def load_center_radius_api(center_radius_dir='center'): '''loading the pre-computed center and radius.''' center_radius_path = os.path.join(center_radius_dir, 'radius_center.pth') load_dict_center_radius = torch.load(center_radius_path) center = load_dict_center_radius['center'] radius = load_dict_center_radius['radius'] center, radius = center.to(device), radius.to(device) return center, radius ================================================ FILE: utils/load_data.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ from os.path import isfile, join from PIL import Image from torchvision import transforms import numpy as np import abc import cv2 import torch.utils.data as data import torch.nn.functional as F import random random.seed(1234567890) from random import randrange import torch.nn as nn import torch import imageio import time import math import torch class BaseData(data.Dataset): ''' The dataset used for the IFDL dataset. ''' def __init__(self, args): super(BaseData, self).__init__() self.crop_size = args.crop_size self.file_path = '/user/guoxia11/cvlshare/cvl-guoxia11/IMDL/REAL' self.file_path_fake = '/user/guoxia11/cvlshare/cvl-guoxia11/IMDL/FAKE' # Real and Fake images. self.image_names = [] self.image_class = self._img_list_retrieve() for idx, _ in enumerate(self.image_class): self.image_names += _ def __getitem__(self, index): res = self.get_item(index) return res def __len__(self): return len(self.image_names) @abc.abstractmethod def _img_list_retrieve(): pass def _resize_func(self, input_img): '''resize the input image into the crop size.''' input_img = Image.fromarray(input_img) input_img = input_img.resize(self.crop_size, resample=Image.BICUBIC) input_img = np.asarray(input_img) return input_img def get_image(self, image_name, aug_index=None): '''transform the image.''' image = imageio.imread(image_name) if image.shape[-1] == 4: image = self.rgba2rgb(image) image = self._resize_func(image) image = image.astype(np.float32) / 255. image = torch.from_numpy(image) return image.permute(2, 0, 1) def rgba2rgb(self, rgba, background=(255, 255, 255)): ''' turn rgba to rgb. ''' row, col, ch = rgba.shape rgb = np.zeros((row, col, 3), dtype='float32') r, g, b, a = rgba[:, :, 0], rgba[:, :, 1], rgba[:, :, 2], rgba[:, :, 3] a = np.asarray(a, dtype='float32') / 255.0 R, G, B = background rgb[:, :, 0] = r * a + (1.0 - a) * R rgb[:, :, 1] = g * a + (1.0 - a) * G rgb[:, :, 2] = b * a + (1.0 - a) * B return np.asarray(rgb, dtype='uint8') def generate_4masks(self, mask): '''generate 4 masks at different scale.''' crop_height, crop_width = self.crop_size ma_height, ma_width = mask.shape[:2] mask_pil = Image.fromarray(mask) if ma_height != crop_height or ma_width != crop_width: mask_pil = mask_pil.resize(self.crop_size, resample=Image.BICUBIC) (width2, height2) = (mask_pil.width // 2, mask_pil.height // 2) (width3, height3) = (mask_pil.width // 4, mask_pil.height // 4) (width4, height4) = (mask_pil.width // 8, mask_pil.height // 8) mask2 = mask_pil.resize((width2, height2)) mask3 = mask_pil.resize((width3, height3)) mask4 = mask_pil.resize((width4, height4)) mask = np.asarray(mask_pil) mask = mask.astype(np.float32) / 255 mask[mask > 0.5] = 1 mask[mask <= 0.5] = 0 mask2 = np.asarray(mask2).astype(np.float32) / 255 mask2[mask2 > 0.5] = 1 mask2[mask2 <= 0.5] = 0 mask3 = np.asarray(mask3).astype(np.float32) / 255 mask3[mask3 > 0.5] = 1 mask3[mask3 <= 0.5] = 0 mask4 = np.asarray(mask4).astype(np.float32) / 255 mask4[mask4 > 0.5] = 1 mask4[mask4 <= 0.5] = 0 mask = torch.from_numpy(mask) mask2 = torch.from_numpy(mask2) mask3 = torch.from_numpy(mask3) mask4 = torch.from_numpy(mask4) # print(mask.size(), mask2.size(), mask3.size(), mask4.size()) return mask, mask2, mask3, mask4 def get_mask(self, image_name, cls, aug_index=None): '''given the cls, we return the mask.''' # authentic if cls in [0,1,2,3,4]: mask = self.load_mask('', real=True, aug_index=aug_index) return_res = [0,0,0,0] # splice elif cls == 5: if '.jpg' in image_name: mask_name = image_name.replace('fake', 'mask').replace('.jpg', '.png') else: mask_name = image_name.replace('fake', 'mask').replace('.tif', '.png') mask = self.load_mask(mask_name, aug_index=aug_index) return_res = [1,1,1,cls - 4] # inpainting elif cls == 6: mask_name = image_name.replace('/fake/', '/mask/').replace('.jpg', '.png') mask = self.load_mask(mask_name, aug_index=aug_index) return_res = [1,1,1,cls - 4] # copy-move elif cls == 7: mask_name = image_name.replace('.png', '_mask.png') mask_name = mask_name.replace('CopyMove', 'CopyMove_mask') mask = self.load_mask(mask_name, aug_index=aug_index) return_res = [1,1,1,cls - 4] # faceshifter elif cls == 8: image_id = image_name.split('/')[-1].split('.')[0] mask_name = image_name.replace(image_id, f'mask/{image_id}_mask') mask = self.load_mask(mask_name, aug_index=aug_index) return_res = [1,2,2,cls - 4] # STGAN elif cls == 9: image_id = image_name.split('/')[-1].split('.')[0] mask_name = image_name.replace('fake', 'mask').replace(image_id, f'{image_id}_label') mask = self.load_mask(mask_name, aug_index=aug_index) return_res = [1,2,2,cls - 4] ## they are star2, hisd, stylegan2, stylegan3, ddpm, ddim, latent, guided elif cls in [10,11,12,13,14,15,16,17]: mask = self.load_mask('', real=False, full_syn=True, aug_index=aug_index) if cls in [10,11]: return_res = [2,3,3,cls-4] elif cls in [12,13]: return_res = [2,3,4,cls-4] elif cls in [14,15]: return_res = [2,4,5,cls-4] elif cls in [16,17]: return_res = [2,4,6,cls-4] else: print(cls, index) raise Exception('class is not defined!') return mask, return_res def load_mask(self, mask_name, real=False, full_syn=False, gray=True, aug_index=None): '''binarize the mask, given the mask_name.''' if real: mask = np.zeros(self.crop_size) else: if not full_syn: mask = imageio.imread(mask_name) if not gray else np.asarray(Image.open(mask_name).convert('RGB').convert('L')) mask = mask.astype(np.float32) else: mask = np.ones(self.crop_size) mask = self.generate_4masks(mask) return mask def get_cls(self, image_name): '''return the forgery/authentic cls given the image_name.''' if '/authentic/' in image_name: return_cls = 0 elif '/REAL/LSUN/' in image_name: return_cls = 0 elif '/afhq_v2/' in image_name: return_cls = 1 elif '/CelebAHQ/' in image_name: return_cls = 2 elif '/FFHQ/' in image_name: return_cls = 3 elif '/Youtube' in image_name: return_cls = 4 elif '/splice' in image_name: return_cls = 5 elif '/Inpainting' in image_name: return_cls = 6 elif '/CopyMove' in image_name: return_cls = 7 elif '/FaShifter' in image_name: return_cls = 8 elif '/STGAN' in image_name: return_cls = 9 elif '/Star2' in image_name: return_cls = 10 elif '/HiSD' in image_name: return_cls = 11 elif '/STYL2' in image_name: return_cls = 12 elif '/STYL3' in image_name: return_cls = 13 elif '/DDPM_' in image_name: return_cls = 14 elif '/DDIM_' in image_name: return_cls = 15 elif '/D_latent' in image_name: return_cls = 16 elif '/GLIDE/' in image_name: return_cls = 17 else: print(image_name) raise ValueError return return_cls class TrainData(BaseData): ''' The dataset used for the IFDL dataset. ''' def __init__(self, args): self.is_train = True self.val_num = 90000 super(TrainData, self).__init__(args) def img_retrieve(self, file_text, file_folder, real=True): ''' Parameters: file_text: str, text file for images. file_folder: str, images folder. Returns: the image list. ''' result_list = [] val_num = self.val_num * 3 if file_text in ["Youtube", "FaShifter"] else self.val_num data_path = self.file_path if real else self.file_path_fake data_text = join(data_path, file_text) data_path = join(data_path, file_folder) file_handler = open(data_text) contents = file_handler.readlines() if self.is_train: contents_lst = contents[:val_num] else: contents_lst = contents[val_num:] for content in contents_lst: if '.npy' not in content and 'mask' not in content: img_name = content.strip() img_name = join(data_path, img_name) result_list.append(img_name) file_handler.close() ## only truncate the val_num images. if len(result_list) < val_num: mul_factor = (val_num//len(result_list)) + 2 result_list = result_list * mul_factor result_list = result_list[-val_num:] return result_list def get_item(self, index): ''' given the index, this function returns the image with the forgery mask this function calls get_image, get_mask for the image and mask torch tensor. ''' image_name = self.image_names[index] cls = self.get_cls(image_name) # image and mask aug_index = randrange(0, 8) image = self.get_image(image_name, aug_index) mask, return_res = self.get_mask(image_name, cls, aug_index) return image, mask, return_res[0], return_res[1], return_res[2], return_res[3] def _img_list_retrieve(self): '''Returns image list for different authentic and forgery image.''' authentic_names = self.img_retrieve('authentic.txt', 'authentic') splice_names = self.img_retrieve('splice_randmask.txt', 'splice_randmask/fake',False) inpainting_names = self.img_retrieve('Inpainting.txt', 'Inpainting/fake', False) copymove_names = self.img_retrieve('copy_move.txt', 'CopyMove', False) STGAN_names = self.img_retrieve('STGAN.txt', 'STGAN/fake', False) FaShifter_names = self.img_retrieve('FaShifter.txt', 'FaShifter', False) return [authentic_names, splice_names, inpainting_names, copymove_names, STGAN_names, FaShifter_names] class ValData(BaseData): ''' The dataset used for the IFDL dataset. ''' def __init__(self, args): self.is_train = False self.val_num = 900 super(ValData, self).__init__(args) def img_retrieve(self, file_text, file_folder, real=True): ''' Parameters: file_text: str, text file for images. file_folder: str, images folder. Returns: the image list. ''' result_list = [] val_num = self.val_num * 3 if file_text in ["Youtube", "FaShifter"] else self.val_num data_path = self.file_path if real else self.file_path_fake data_text = join(data_path, file_text) data_path = join(data_path, file_folder) file_handler = open(data_text) contents = file_handler.readlines() for content in contents[-val_num:]: if '.npy' not in content and 'mask' not in content: img_name = content.strip() img_name = join(data_path, img_name) result_list.append(img_name) file_handler.close() ## only truncate the val_num images. if len(result_list) < val_num: mul_factor = (val_num//len(result_list)) + 2 result_list = result_list * mul_factor result_list = result_list[-val_num:] return result_list def get_item(self, index): ''' given the index, this function returns the image with the forgery mask this function calls get_image, get_mask for the image and mask torch tensor. ''' image_name = self.image_names[index] cls = self.get_cls(image_name) # image image = self.get_image(image_name) mask, return_res = self.get_mask(image_name, cls) return image, mask, return_res[0], return_res[1], return_res[2], return_res[3], image_name def _img_list_retrieve(self): '''Returns image list for different authentic and forgery image.''' STGAN_names = self.img_retrieve('STGAN.txt', 'STGAN/fake', False) FaShifter_names = self.img_retrieve('FaShifter.txt', 'FaShifter', False) return [STGAN_names, FaShifter_names] ================================================ FILE: utils/load_edata.py ================================================ from PIL import Image from torchvision import transforms from os.path import join import abc import numpy as np import torch import torch.utils.data as data import imageio import os class BaseData(data.Dataset): ''' The dataset used for the IFDL dataset. ''' def __init__(self, args): super(BaseData, self).__init__() self.crop_size = args.crop_size ## demo dataset: self.mani_data_dir = './data_dir' ## the full dataset: # self.mani_data_dir = './data' self.image_names = [] self.image_class = [] self.mask_names = [] def __getitem__(self, index): res = self.get_item(index) return res def __len__(self): return len(self.image_names) def generate_mask(self, mask): ''' generate the corresponding binary mask. ''' mask = mask.astype(np.float32) / 255 mask[mask > 0.5] = 1 mask[mask <= 0.5] = 0 mask = np.expand_dims(mask, axis=0) mask = torch.from_numpy(mask) return mask def rgba2rgb(self, rgba, background=(255, 255, 255)): ''' turn rgba to rgb. ''' row, col, ch = rgba.shape rgb = np.zeros((row, col, 3), dtype='float32') r, g, b, a = rgba[:, :, 0], rgba[:, :, 1], rgba[:, :, 2], rgba[:, :, 3] a = np.asarray(a, dtype='float32') / 255.0 R, G, B = background rgb[:, :, 0] = r * a + (1.0 - a) * R rgb[:, :, 1] = g * a + (1.0 - a) * G rgb[:, :, 2] = b * a + (1.0 - a) * B return np.asarray(rgb, dtype='uint8') # the output value is uint8 that belongs to [0,255] def get_image(self, image_name): ''' return the image with the tensor. ''' image = imageio.imread(image_name) if len(image.shape) == 2: image = imageio.imread(image_name, as_gray=False, pilmode="RGB") if image.shape[-1] == 4: image = self.rgba2rgb(image) image = torch.from_numpy(image.astype(np.float32) / 255) return image.permute(2, 0, 1) def get_mask(self, mask_name): ''' return the binary mask. ''' mask = Image.open(mask_name).convert('L') mask = mask.resize(self.crop_size, resample=Image.BICUBIC) mask = np.asarray(mask) mask = self.generate_mask(mask) return mask @abc.abstractmethod def get_item(self, index): ''' blur image = Image.fromarray(image) image = image.filter(ImageFilter.GaussianBlur(radius=7)) image = np.asarray(image) resize image = Image.fromarray(image) image = image.resize((int(image.width*0.25), int(image.height*0.25)), resample=Image.BILINEAR) image = np.asarray(image) noise import skimage image = skimage.util.random_noise(image/255., mode='gaussian', mean=0, var=15/255) * 255 jpeg compression im = Image.open(image_name) temp_name = './temp/' + image_name.split('/')[-1][:-3] + 'jpg' im.save(temp_name, 'JPEG', quality=50) image = Image.open(temp_name) image = np.asarray(image) ''' pass class ValColumbia(BaseData): def __init__(self, args): super(ValColumbia, self).__init__(args) ddir = os.path.join(self.mani_data_dir, 'columbia') with open(join(ddir, 'vallist.txt')) as f: contents = f.readlines() for content in contents: _ = os.path.join(ddir, '4cam_splc', content.strip()) self.image_names.append(_) self.image_class = [1] * len(self.image_names) def get_item(self, index): image_name = self.image_names[index] cls = self.image_class[index] # image image = self.get_image(image_name) # mask if '4cam_splc' in image_name: mask_name = image_name.replace('4cam_splc', 'mask').replace('.tif', '.jpg') mask = self.get_mask(mask_name) else: mask = np.zeros((1, 256, 256), dtype='float32') return image, mask, cls, image_name class ValCoverage(BaseData): def __init__(self, args): super(ValCoverage, self).__init__(args) ddir = os.path.join(self.mani_data_dir, 'Coverage') with open(join(ddir, 'fake.txt')) as f: contents = f.readlines() for content in contents: _ = os.path.join(ddir, 'image', content.strip()) self.image_names.append(_) self.image_class = [2] * len(self.image_names) def get_item(self, index): image_name = self.image_names[index] cls = self.image_class[index] ## read image. image = self.get_image(image_name) # mask mask_name = image_name.replace('image', 'mask').replace('t.tif', 'forged.tif') mask = self.get_mask(mask_name) return image, mask, cls, image_name class ValCasia(BaseData): def __init__(self, args): super(ValCasia, self).__init__(args) ddir = os.path.join(self.mani_data_dir, 'CASIA/CASIA1') with open(join(ddir, 'fake.txt')) as f: contents = f.readlines() for content in contents: tag = content.split('/')[-1].split('_')[1] if tag == 'D': self.image_class.append(1) elif tag == 'S': self.image_class.append(2) else: raise Exception('unknown class: {}'.format(content)) self.image_names.append(os.path.join(ddir, 'fake', content.strip())) ddir = os.path.join(self.mani_data_dir, 'CASIA/CASIA2') with open(join(ddir, 'fake.txt')) as f: contents = f.readlines() for content in contents: tag = content.split('/')[-1].split('_')[1] if tag == 'D': self.image_class.append(1) elif tag == 'S': self.image_class.append(2) else: raise Exception('unknown class: {}'.format(content)) self.image_names.append(os.path.join(ddir, 'fake', content.strip())) def get_item(self, index): image_name = self.image_names[index] cls = self.image_class[index] # image image = self.get_image(image_name) # mask if '.jpg' in image_name: mask_name = image_name.replace('fake', 'mask').replace('.jpg', '_gt.png') else: mask_name = image_name.replace('fake', 'mask').replace('.tif', '_gt.png') mask = self.get_mask(mask_name) return image, mask, cls, image_name class ValNIST16(BaseData): def __init__(self, args): super(ValNIST16, self).__init__(args) ddir = os.path.join(self.mani_data_dir, 'NIST16') file_name = 'alllist.txt' with open(join(ddir, file_name)) as f: contents = f.readlines() for content in contents: image_name, mask_name = content.split(' ') self.image_names.append(join(ddir, image_name)) self.mask_names.append(join(ddir, mask_name.strip())) def get_item(self, index): image_name = self.image_names[index] mask_name = self.mask_names[index] if 'splice' in mask_name: cls = 1 elif 'manipulation' in mask_name: cls = 2 elif 'remove' in mask_name: cls = 3 else: cls = 0 # image image = self.get_image(image_name) if image.size()[2]*image.size()[1] >= 1000*1000: image = imageio.imread(image_name) if image.shape[-1] == 4: image = self.rgba2rgb(image) image = Image.fromarray(image) image = image.resize((1000, 1000), resample=Image.BICUBIC) image = np.asarray(image) image = torch.from_numpy(image.astype(np.float32) / 255) image = image.permute(2, 0, 1) # mask mask = self.get_mask(mask_name) mask = torch.abs(mask - 1) return image, mask, cls, image_name class ValIMD2020(BaseData): def __init__(self, args): super(ValIMD2020, self).__init__(args) ddir = os.path.join(self.mani_data_dir, 'IMD2020') file_name = 'fake.txt' with open(join(ddir, file_name)) as f: contents = f.readlines() for content in contents: image_name = content.strip() if '.jpg' in image_name: mask_name = image_name.replace('.jpg', '_mask.png') else: mask_name = image_name.replace('.png', '_mask.png') self.image_names.append(join(ddir, 'fake_img', image_name)) self.mask_names.append(join(ddir, 'mask', mask_name)) self.image_class = [2] * len(self.image_names) def get_item(self, index): image_name = self.image_names[index] mask_name = self.mask_names[index] cls = self.image_class[index] try: image = self.get_image(image_name) except: print(f"Fail at {image_name}.") mask = self.get_mask(mask_name) return image, mask, cls, image_name ================================================ FILE: utils/utils.py ================================================ # ------------------------------------------------------------------------------ # Author: Xiao Guo (guoxia11@msu.edu) # CVPR2023: Hierarchical Fine-Grained Image Forgery Detection and Localization # ------------------------------------------------------------------------------ import os import time import torch import torch.nn as nn import numpy as np import torch.nn.functional as F import matplotlib.pyplot as plt from torch.optim.lr_scheduler import ReduceLROnPlateau from kmeans_pytorch import kmeans from torchvision import transforms from torch.utils.data import DataLoader from sklearn import metrics from torchvision.utils import make_grid from einops import rearrange from PIL import Image Softmax_m = nn.Softmax(dim=1) device = torch.device('cuda:0') def device_ids_return(cuda_list): '''return the device id''' if len(cuda_list) == 1: device_ids = [0] elif len(cuda_list) == 2: device_ids = [0,1] elif len(cuda_list) == 3: device_ids = [0,1,2] elif len(cuda_list) == 4: device_ids = [0,1,2,3] elif len(cuda_list) == 7: device_ids = [0,1,2,3,4,5,6] return device_ids def findLastCheckpoint(save_dir): if os.path.exists(save_dir): file_list = os.listdir(save_dir) result = 0 for file in file_list: try: num = int(file.split('.')[0].split('_')[-1]) result = max(result, num) except: continue return result else: os.mkdir(save_dir) return 0 def get_confusion_matrix(y_true, y_pred): return metrics.confusion_matrix(y_true, y_pred) def compute_cls_acc_f1(label_lst, pred_lst, iter_num, tb_writer, descr='Val/level3_1'): F1 = metrics.f1_score(label_lst, pred_lst, average='macro') acc = metrics.accuracy_score(label_lst, pred_lst) tb_writer.add_scalar(f'{descr}_F1', F1, iter_num) tb_writer.add_scalar(f'{descr}_acc', acc, iter_num) print(f"In {descr}, the image-level Acc: {acc:.3f}, F1: {F1:.3f}.") print("******************************************************") return F1, acc def tb_writer_display(writer, iter_num, lr_scheduler, epoch, seg_accu, loc_map_loss, manipul_loss, natural_loss, binary_loss, loss_1, loss_2, loss_3, loss_4): writer.add_scalar('Train/seg_accu', seg_accu, iter_num) writer.add_scalar('Train/map_loss', loc_map_loss, iter_num) writer.add_scalar('Train/binary_map_loss', binary_loss, iter_num) writer.add_scalar('Train/manip_loss', manipul_loss, iter_num) writer.add_scalar('Train/natur_loss', natural_loss, iter_num) writer.add_scalar('Train/loss_1', loss_1, iter_num) writer.add_scalar('Train/loss_2', loss_2, iter_num) writer.add_scalar('Train/loss_3', loss_3, iter_num) writer.add_scalar('Train/loss_4', loss_3, iter_num) for count, gp in enumerate(lr_scheduler.optimizer.param_groups,1): writer.add_scalar('progress/lr_%d'%count, gp['lr'], iter_num) writer.add_scalar('progress/epoch', epoch, iter_num) writer.add_scalar('progress/curr_patience',lr_scheduler.num_bad_epochs,iter_num) writer.add_scalar('progress/patience',lr_scheduler.patience,iter_num) def one_hot_label(vector, Softmax_m=Softmax_m): x = Softmax_m(vector) x = torch.argmax(x, dim=1) return x def one_hot_label_new(vector, Softmax_m=Softmax_m): ''' compute the probability for being as the synthesized image (TODO: double check). ''' x = Softmax_m(vector) indices = torch.argmax(x, dim=1) prob = 1 - x[:,0] indices = list(indices.cpu().numpy()) prob = list(prob.cpu().numpy()) return indices, prob def level_1_convert(input_lst): res_lst = [] for _ in input_lst: if _ == 0: res_lst.append(0) else: res_lst.append(1) return res_lst def confusion_matrix_display(label_lst, res_lst, display_lst, display_name): confusion_matrix = metrics.confusion_matrix(label_lst, res_lst) cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, display_labels = display_lst) cm_display.plot() plt.savefig(f'{display_name}.png') confusion_matrix = metrics.confusion_matrix(label_lst, res_lst, normalize='true') cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, display_labels = display_lst) cm_display.plot() plt.savefig(f'{display_name}_normalized.png') def make_folder(folder_name): if not os.path.exists(folder_name): os.makedirs(folder_name, exist_ok=True) print(f"Making folder {folder_name}.") else: print(f"Folder {folder_name} exists.") def class_weight(mask, mask_idx): '''balance the weight on real and forgery pixel.''' mask_balance = torch.ones_like(mask).to(torch.float) if (mask == 1).sum(): mask_balance[mask == 1] = 0.5 / ((mask == 1).sum().to(torch.float) / mask.numel()) mask_balance[mask == 0] = 0.5 / ((mask == 0).sum().to(torch.float) / mask.numel()) else: pass # print(f'Mask{mask_idx} balance is not working!') return mask.to(device), mask_balance.to(device) def setup_optimizer(args, SegNet, FENet): '''setup the optimizier, which applies different learning rate on different layers.''' '''different hyper-parameters are changed towards HiFi-IFDL dataset.''' params_dict_list = [] params_dict_list.append({'params' : SegNet.module.parameters(), 'lr' : args.learning_rate}) freq_list = [] para_list = [] for name, param in FENet.named_parameters(): if 'fre' in name: freq_list += [param] else: para_list += [param] params_dict_list.append({'params' : freq_list, 'lr' : args.learning_rate*args.lr_backbone}) params_dict_list.append({'params' : para_list, 'lr' : args.learning_rate}) optimizer = torch.optim.Adam(params_dict_list, lr=args.learning_rate*0.75, weight_decay=1e-06) lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=args.step_factor, min_lr=1e-08, patience=args.patience, verbose=True) return optimizer, lr_scheduler def restore_weight_helper(model, model_dir, initial_epoch): '''load model given the model_dir that has the model weights.''' try: weight_path = '{}/{}.pth'.format(model_dir, initial_epoch) state_dict = torch.load(weight_path, map_location='cuda:0')['model'] model.load_state_dict(state_dict) print('{} weight-loading succeeds: {}'.format(model_dir, weight_path)) except: print('{} weight-loading fails'.format(model_dir)) pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print("{}_params: {}".format(model_dir, pytorch_total_params)) return model def restore_optimizer(optimizer, model_dir): '''restore the optimizer.''' try: weight_path = '{}/{}.pth'.format(model_dir, initial_epoch) state_dict = torch.load(weight_path, map_location='cuda:0') print('Optimizer weight-loading succeeds.') optimizer.load_state_dict(state_dict['optimizer']) except: # print('{} Optimizer weight-loading fails.') pass return optimizer def composite_obj(args, loss, loss_1, loss_2, loss_3, loss_4, loss_binary): ''' 'base', 'fg', 'local', 'full' ''' if args.ablation == 'full': # fine-grained + localization loss_total = 100*loss + loss_1 + loss_2 + loss_3 + 100*loss_4 + loss_binary elif args.ablation == 'base': # one-shot loss_total = loss_4 elif args.ablation == 'fg': # only fine-grained loss_total = loss_1 + loss_2 + loss_3 + loss_4 elif args.ablation == 'local': # only loclization loss_total = loss + 10e-6*(loss_1 + loss_2 + loss_3 + loss_4) else: assert False return loss_total def composite_obj_step(args, loss_4_sum, map_loss_sum): ''' return loss for the scheduler ''' if args.ablation == 'full': schedule_step_loss = loss_4_sum + map_loss_sum elif args.ablation == 'base': schedule_step_loss = loss_4_sum elif args.ablation == 'fg': schedule_step_loss = loss_4_sum elif args.ablation == 'local': schedule_step_loss = map_loss_sum else: assert False return schedule_step_loss def viz_log(args, mask, pred_mask, image, iter_num, step, mode='train'): '''viz training, val and inference.''' mask = torch.unsqueeze(mask, dim=1) pred_mask = torch.unsqueeze(pred_mask, dim=1) mask_viz = torch.cat([mask]*3, axis=1) pred_mask = torch.cat([pred_mask]*3, axis=1) image = torch.nn.functional.interpolate(image, # for viz. size=(256, 256), mode='bilinear') fig_viz = torch.cat([mask_viz, image, pred_mask], axis=0) grid = make_grid(fig_viz, nrow=mask_viz.shape[0]) # nrow in fact is the column number. grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() img_h = Image.fromarray(grid.astype(np.uint8)) # os.makedirs(f"./viz_{mode}_{args.learning_rate}/", exist_ok=True) os.makedirs(f"./viz_{mode}/", exist_ok=True) if mode == 'train': # img_h.save(f"./viz_{mode}_{args.learning_rate}/iter_{iter_num}.jpg") img_h.save(f"./viz_{mode}/iter_{iter_num}.jpg") else: # img_h.save(f"./viz_{mode}_{args.learning_rate}/iter_{iter_num}_step_{step}.jpg") img_h.save(f"./viz_{mode}/iter_{iter_num}_step_{step}.jpg") def process_mask(mask, pred_mask): '''process the mask''' pred_mask = torch.unsqueeze(pred_mask, dim=1) mask = torch.unsqueeze(mask, dim=1) pred_mask = torch.cat([pred_mask]*3, axis=1) mask = torch.cat([mask]*3, axis=1) pred_mask = nn.functional.interpolate(pred_mask, size=(256, 256), mode='bilinear') mask = nn.functional.interpolate(mask, size=(256, 256), mode='bilinear') return pred_mask, mask def viz_logs_scale(args, iter_num, mask_128, mask_64, mask_32, mask2, mask3, mask4, mode='train'): '''visualize the mask and predicted mask.''' pred_mask_128, mask128 = process_mask(mask_128, mask2) pred_mask_64, mask64 = process_mask(mask_64, mask3) pred_mask_32, mask32 = process_mask(mask_32, mask4) fig_viz = torch.cat([pred_mask_32, mask32, pred_mask_64, mask64, pred_mask_128, mask128], axis=0) grid = make_grid(fig_viz, nrow=pred_mask_32.shape[0]) grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() img_h = Image.fromarray(grid.astype(np.uint8)) os.makedirs(f"./viz_{mode}_{args.learning_rate}/", exist_ok=True) img_h.save(f"./viz_{mode}_{args.learning_rate}/iter_{iter_num}_pred.jpg") def train_log_dump(args, seg_correct, seg_total, map_loss_sum, mani_lss_sum, natu_lss_sum, binary_map_loss_sum, loss_1_sum, loss_2_sum, loss_3_sum, loss_4_sum, epoch, step, writer, iter_num, lr_scheduler): '''compute and output the different training loss & seg accuarcy.''' seg_accu = seg_correct / seg_total * 100 loc_map_loss = map_loss_sum / args.dis_step manipul_loss = mani_lss_sum / args.dis_step natural_loss = natu_lss_sum / args.dis_step binary_loss = binary_map_loss_sum / args.dis_step loss_1 = loss_1_sum / args.dis_step loss_2 = loss_2_sum / args.dis_step loss_3 = loss_3_sum / args.dis_step loss_4 = loss_4_sum / args.dis_step print(f'[Epoch: {epoch+1}, Step: {step + 1}] batch_loc_acc: {seg_accu:.2f}') print(f'cls1_loss: {loss_1:.3f}, cls2_loss: {loss_2:.3f}, cls3_loss: {loss_3:.3f}, '+ f'cls4_loss: {loss_4:.3f}, map_loss: {loc_map_loss:.3f}, '+ f'manip_loss: {manipul_loss:.3f}, natur_loss: {natural_loss:.3f}, '+ f'binary_map_loss: {binary_loss:.3f}') '''write the tensorboard.''' tb_writer_display(writer, iter_num, lr_scheduler, epoch, seg_accu, loc_map_loss, manipul_loss, natural_loss, binary_loss, loss_1, loss_2, loss_3, loss_4) ================================================ FILE: weights/put_weights_here ================================================