diff --git a/README.md b/README.md index fb81205..0d341e8 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ conda activate arena With your virtual env activated, you can install `dfd-arena`: ```bash -cd dfd-arena && pip install -e +cd dfd-arena && pip install -e . ``` ## Usage diff --git a/arena/detectors/NPR/NPR.png b/arena/architectures/NPR/NPR.png similarity index 100% rename from arena/detectors/NPR/NPR.png rename to arena/architectures/NPR/NPR.png diff --git a/arena/detectors/NPR/README.md b/arena/architectures/NPR/README.md similarity index 100% rename from arena/detectors/NPR/README.md rename to arena/architectures/NPR/README.md diff --git a/arena/detectors/NPR/__init__.py b/arena/architectures/NPR/__init__.py similarity index 100% rename from arena/detectors/NPR/__init__.py rename to arena/architectures/NPR/__init__.py diff --git a/arena/detectors/NPR/config/constants.py b/arena/architectures/NPR/config/constants.py similarity index 100% rename from arena/detectors/NPR/config/constants.py rename to arena/architectures/NPR/config/constants.py diff --git a/arena/detectors/NPR/download_dataset.sh b/arena/architectures/NPR/download_dataset.sh similarity index 100% rename from arena/detectors/NPR/download_dataset.sh rename to arena/architectures/NPR/download_dataset.sh diff --git a/arena/detectors/NPR/networks/__init__.py b/arena/architectures/NPR/networks/__init__.py similarity index 100% rename from arena/detectors/NPR/networks/__init__.py rename to arena/architectures/NPR/networks/__init__.py diff --git a/arena/detectors/NPR/networks/base_model.py b/arena/architectures/NPR/networks/base_model.py similarity index 100% rename from arena/detectors/NPR/networks/base_model.py rename to arena/architectures/NPR/networks/base_model.py diff --git a/arena/detectors/NPR/networks/resnet.py b/arena/architectures/NPR/networks/resnet.py similarity index 100% rename from arena/detectors/NPR/networks/resnet.py rename to arena/architectures/NPR/networks/resnet.py diff --git a/arena/detectors/NPR/networks/trainer.py b/arena/architectures/NPR/networks/trainer.py similarity index 100% rename from arena/detectors/NPR/networks/trainer.py rename to arena/architectures/NPR/networks/trainer.py diff --git a/arena/detectors/NPR/options/__init__.py b/arena/architectures/NPR/options/__init__.py similarity index 100% rename from arena/detectors/NPR/options/__init__.py rename to arena/architectures/NPR/options/__init__.py diff --git a/arena/detectors/NPR/options/base_options.py b/arena/architectures/NPR/options/base_options.py similarity index 100% rename from arena/detectors/NPR/options/base_options.py rename to arena/architectures/NPR/options/base_options.py diff --git a/arena/detectors/NPR/options/test_options.py b/arena/architectures/NPR/options/test_options.py similarity index 100% rename from arena/detectors/NPR/options/test_options.py rename to arena/architectures/NPR/options/test_options.py diff --git a/arena/detectors/NPR/options/train_options.py b/arena/architectures/NPR/options/train_options.py similarity index 100% rename from arena/detectors/NPR/options/train_options.py rename to arena/architectures/NPR/options/train_options.py diff --git a/arena/detectors/NPR/requirements.txt b/arena/architectures/NPR/requirements.txt similarity index 100% rename from arena/detectors/NPR/requirements.txt rename to arena/architectures/NPR/requirements.txt diff --git a/arena/detectors/NPR/test.py b/arena/architectures/NPR/test.py similarity index 100% rename from arena/detectors/NPR/test.py rename to arena/architectures/NPR/test.py diff --git a/arena/detectors/NPR/train_detector.py b/arena/architectures/NPR/train_detector.py similarity index 100% rename from arena/detectors/NPR/train_detector.py rename to arena/architectures/NPR/train_detector.py diff --git a/arena/detectors/NPR/util/__init__.py b/arena/architectures/NPR/util/__init__.py similarity index 100% rename from arena/detectors/NPR/util/__init__.py rename to arena/architectures/NPR/util/__init__.py diff --git a/arena/detectors/NPR/util/eval.py b/arena/architectures/NPR/util/eval.py similarity index 100% rename from arena/detectors/NPR/util/eval.py rename to arena/architectures/NPR/util/eval.py diff --git a/arena/detectors/NPR/validate.py b/arena/architectures/NPR/validate.py similarity index 100% rename from arena/detectors/NPR/validate.py rename to arena/architectures/NPR/validate.py diff --git a/arena/detectors/deepfake_detectors/unit_tests/__init__.py b/arena/architectures/SPSL/__init__.py similarity index 100% rename from arena/detectors/deepfake_detectors/unit_tests/__init__.py rename to arena/architectures/SPSL/__init__.py diff --git a/arena/architectures/SPSL/config/__init__.py b/arena/architectures/SPSL/config/__init__.py new file mode 100644 index 0000000..676145d --- /dev/null +++ b/arena/architectures/SPSL/config/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) diff --git a/arena/architectures/SPSL/config/constants.py b/arena/architectures/SPSL/config/constants.py new file mode 100644 index 0000000..08d0cac --- /dev/null +++ b/arena/architectures/SPSL/config/constants.py @@ -0,0 +1,13 @@ +import os + +# Path to the directory containing the constants.py file +CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__)) + +# The base directory for related files +BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, "..")) +# Absolute paths for the required files and directories +CONFIG_PATH = os.path.join(CONFIGS_DIR, "spsl.yaml") # Path to the .yaml file +WEIGHTS_DIR = os.path.join(BASE_PATH, "weights/") # Path to pretrained weights directory + +HF_REPO = "bitmind/spsl" +BACKBONE_CKPT = "xception_best.pth" \ No newline at end of file diff --git a/arena/architectures/SPSL/config/spsl.yaml b/arena/architectures/SPSL/config/spsl.yaml new file mode 100644 index 0000000..9ba8f48 --- /dev/null +++ b/arena/architectures/SPSL/config/spsl.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /mntcephfs/lab_data/zhiyuanyan/benchmark_results/logs_final/spsl_4frames + +# model setting +pretrained: ../weights/xception_best.pth # path to a pre-trained model, if using one +# pretrained: /home/tianshuoge/resnet34-b627a593.pth # path to a pre-trained model, if using one +model_name: spsl # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original # shallow_xception + num_classes: 2 + inc: 4 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-FS] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 4, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/arena/architectures/SPSL/config/train_config.yaml b/arena/architectures/SPSL/config/train_config.yaml new file mode 100644 index 0000000..e154ea6 --- /dev/null +++ b/arena/architectures/SPSL/config/train_config.yaml @@ -0,0 +1,43 @@ +mode: train +lmdb: True +dry_run: false +rgb_dir: './datasets/rgb' +lmdb_dir: './datasets/lmdb' +dataset_json_folder: './preprocessing/dataset_json' +SWA: False +save_avg: True +log_dir: ./logs/training/ +# label settings +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + # Roop + roop_Real: 0 + roop_Fake: 1 \ No newline at end of file diff --git a/arena/detectors/UCF/config/xception.yaml b/arena/architectures/SPSL/config/xception.yaml similarity index 100% rename from arena/detectors/UCF/config/xception.yaml rename to arena/architectures/SPSL/config/xception.yaml diff --git a/arena/architectures/SPSL/detectors/__init__.py b/arena/architectures/SPSL/detectors/__init__.py new file mode 100644 index 0000000..d16b773 --- /dev/null +++ b/arena/architectures/SPSL/detectors/__init__.py @@ -0,0 +1,11 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +from metrics.registry import DETECTOR + +from .spsl_detector import SpslDetector \ No newline at end of file diff --git a/arena/architectures/SPSL/detectors/base_detector.py b/arena/architectures/SPSL/detectors/base_detector.py new file mode 100644 index 0000000..b240972 --- /dev/null +++ b/arena/architectures/SPSL/detectors/base_detector.py @@ -0,0 +1,71 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Abstract Class for the Deepfake Detector + +import abc +import torch +import torch.nn as nn +from typing import Union + +class AbstractDetector(nn.Module, metaclass=abc.ABCMeta): + """ + All deepfake detectors should subclass this class. + """ + def __init__(self, config=None, load_param: Union[bool, str] = False): + """ + config: (dict) + configurations for the model + load_param: (False | True | Path(str)) + False Do not read; True Read the default path; Path Read the required path + """ + super().__init__() + + @abc.abstractmethod + def features(self, data_dict: dict) -> torch.tensor: + """ + Returns the features from the backbone given the input data. + """ + pass + + @abc.abstractmethod + def forward(self, data_dict: dict, inference=False) -> dict: + """ + Forward pass through the model, returning the prediction dictionary. + """ + pass + + @abc.abstractmethod + def classifier(self, features: torch.tensor) -> torch.tensor: + """ + Classifies the features into classes. + """ + pass + + @abc.abstractmethod + def build_backbone(self, config): + """ + Builds the backbone of the model. + """ + pass + + @abc.abstractmethod + def build_loss(self, config): + """ + Builds the loss function for the model. + """ + pass + + @abc.abstractmethod + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """ + Returns the losses for the model. + """ + pass + + @abc.abstractmethod + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + """ + Returns the training metrics for the model. + """ + pass diff --git a/arena/architectures/SPSL/detectors/spsl_detector.py b/arena/architectures/SPSL/detectors/spsl_detector.py new file mode 100644 index 0000000..46536c8 --- /dev/null +++ b/arena/architectures/SPSL/detectors/spsl_detector.py @@ -0,0 +1,228 @@ +""" +Class for the SPSLDetector. + +This module implements the Spatial-Phase Shallow Learning (SPSL) detector +for face forgery detection in the frequency domain. + +Author: Zhiyuan Yan +Email: zhiyuanyan@link.cuhk.edu.cn +Date: 2023-07-06 + +Functions in the Class: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{liu2021spatial, + title={Spatial-phase shallow learning: rethinking face forgery detection in frequency domain}, + author={Liu, Honggu and Li, Xiaodan and Zhou, Wenbo and Chen, Yuefeng and He, Yuan and Xue, Hui and Zhang, Weiming and Yu, Nenghai}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={772--781}, + year={2021} +} + +Note: +To ensure consistency in the comparison with other detectors, we have opted not to utilize +the shallow Xception architecture. Instead, we are employing the original Xception model. +""" + +import os +import datetime +import logging +from typing import Union +from collections import defaultdict + +import numpy as np +from sklearn import metrics +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='spsl') +class SpslDetector(AbstractDetector): + """ + Spatial-Phase Shallow Learning (SPSL) detector for face forgery detection. + """ + + def __init__(self, config): + """ + Initialize the SPSL detector. + + Args: + config (dict): Configuration dictionary for the detector. + """ + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + """ + Build the backbone network for the detector. + + Args: + config (dict): Configuration dictionary for the backbone. + + Returns: + nn.Module: The constructed backbone network. + + Note: + This method adapts the pretrained 3-channel (RGB) conv1 layer to a + 4-channel input (RGB + phase) by averaging the pretrained weights + across RGB channels and repeating for the 4th channel. This + preserves pretrained knowledge while accommodating the + additional phase information channel. + """ + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k} + + # Create a new conv1 layer with 4 input channels + backbone.conv1 = nn.Conv2d(4, 32, 3, 2, 0, bias=False) + # Check if 'conv1.weight' exists in the original state_dict + if 'backbone.conv1.weight' in state_dict: + conv1_data = state_dict['backbone.conv1.weight'] + # average across the RGB channels + avg_conv1_data = conv1_data.mean(dim=1, keepdim=True) + # repeat the averaged weights across the 4 new channels + backbone.conv1.weight.data = avg_conv1_data.repeat(1, 4, 1, 1) + return backbone + + def build_loss(self, config): + """ + Build the loss function for the detector. + + Args: + config (dict): Configuration dictionary for the loss function. + + Returns: + nn.Module: The constructed loss function. + """ + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict, phase_fea) -> torch.Tensor: + """ + Extract features from the input data. + + Args: + data_dict (dict): Input data dictionary. + phase_fea (torch.Tensor): Phase features. + + Returns: + torch.Tensor: Extracted features. + """ + features = torch.cat((data_dict['image'], phase_fea), dim=1) + return self.backbone.features(features) + + def classifier(self, features: torch.Tensor) -> torch.Tensor: + """ + Classify the extracted features. + + Args: + features (torch.Tensor): Extracted features. + + Returns: + torch.Tensor: Classification output. + """ + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """ + Compute losses for the detector. + + Args: + data_dict (dict): Input data dictionary. + pred_dict (dict): Prediction dictionary. + + Returns: + dict: Dictionary of computed losses. + """ + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + """ + Compute training metrics for the detector. + + Args: + data_dict (dict): Input data dictionary. + pred_dict (dict): Prediction dictionary. + + Returns: + dict: Dictionary of computed metrics. + """ + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + self.video_names = [] + return metric_batch_dict + + def forward(self, data_dict: dict) -> dict: + """ + Forward pass of the detector. + + Args: + data_dict (dict): Input data dictionary. + + Returns: + dict: Dictionary containing prediction results. + """ + phase_fea = self.phase_without_amplitude(data_dict['image']) + features = self.features(data_dict, phase_fea) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + def phase_without_amplitude(self, img): + """ + Extract phase information without amplitude from the input image. + + Args: + img (torch.Tensor): Input image tensor. + + Returns: + torch.Tensor: Reconstructed image using phase information. + """ + # Convert to grayscale + gray_img = torch.mean(img, dim=1, keepdim=True) + # Compute the DFT of the input signal + X = torch.fft.fftn(gray_img, dim=(-1, -2)) + # Extract the phase information from the DFT + phase_spectrum = torch.angle(X) + # Create a new complex spectrum with the phase information and zero magnitude + reconstructed_X = torch.exp(1j * phase_spectrum) + # Use the IDFT to obtain the reconstructed signal + reconstructed_x = torch.real(torch.fft.ifftn(reconstructed_X, dim=(-1, -2))) + return reconstructed_x diff --git a/arena/detectors/UCF/logger.py b/arena/architectures/SPSL/logger.py similarity index 100% rename from arena/detectors/UCF/logger.py rename to arena/architectures/SPSL/logger.py diff --git a/arena/detectors/UCF/config/__init__.py b/arena/architectures/SPSL/loss/__init__.py similarity index 100% rename from arena/detectors/UCF/config/__init__.py rename to arena/architectures/SPSL/loss/__init__.py diff --git a/arena/detectors/UCF/loss/abstract_loss_func.py b/arena/architectures/SPSL/loss/abstract_loss_func.py similarity index 100% rename from arena/detectors/UCF/loss/abstract_loss_func.py rename to arena/architectures/SPSL/loss/abstract_loss_func.py diff --git a/arena/detectors/UCF/loss/cross_entropy_loss.py b/arena/architectures/SPSL/loss/cross_entropy_loss.py similarity index 100% rename from arena/detectors/UCF/loss/cross_entropy_loss.py rename to arena/architectures/SPSL/loss/cross_entropy_loss.py diff --git a/arena/architectures/SPSL/metrics/__init__.py b/arena/architectures/SPSL/metrics/__init__.py new file mode 100644 index 0000000..676145d --- /dev/null +++ b/arena/architectures/SPSL/metrics/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) diff --git a/arena/architectures/SPSL/metrics/base_metrics_class.py b/arena/architectures/SPSL/metrics/base_metrics_class.py new file mode 100644 index 0000000..4e3f33b --- /dev/null +++ b/arena/architectures/SPSL/metrics/base_metrics_class.py @@ -0,0 +1,205 @@ +import numpy as np +from sklearn import metrics +from collections import defaultdict +import torch +import torch.nn as nn + + +def get_accracy(output, label): + _, prediction = torch.max(output, 1) # argmax + correct = (prediction == label).sum().item() + accuracy = correct / prediction.size(0) + return accuracy + + +def get_prediction(output, label): + prob = nn.functional.softmax(output, dim=1)[:, 1] + prob = prob.view(prob.size(0), 1) + label = label.view(label.size(0), 1) + #print(prob.size(), label.size()) + datas = torch.cat((prob, label.float()), dim=1) + return datas + + +def calculate_metrics_for_train(label, output): + if output.size(1) == 2: + prob = torch.softmax(output, dim=1)[:, 1] + else: + prob = output + + # Accuracy + _, prediction = torch.max(output, 1) + correct = (prediction == label).sum().item() + accuracy = correct / prediction.size(0) + + # Average Precision + y_true = label.cpu().detach().numpy() + y_pred = prob.cpu().detach().numpy() + ap = metrics.average_precision_score(y_true, y_pred) + + # AUC and EER + try: + fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(), + prob.squeeze().cpu().numpy(), + pos_label=1) + except: + # for the case when we only have one sample + return None, None, accuracy, ap + + if np.isnan(fpr[0]) or np.isnan(tpr[0]): + # for the case when all the samples within a batch is fake/real + auc, eer = None, None + else: + auc = metrics.auc(fpr, tpr) + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + + return auc, eer, accuracy, ap + + +# ------------ compute average metrics of batches--------------------- +class Metrics_batch(): + def __init__(self): + self.tprs = [] + self.mean_fpr = np.linspace(0, 1, 100) + self.aucs = [] + self.eers = [] + self.aps = [] + + self.correct = 0 + self.total = 0 + self.losses = [] + + def update(self, label, output): + acc = self._update_acc(label, output) + if output.size(1) == 2: + prob = torch.softmax(output, dim=1)[:, 1] + else: + prob = output + #label = 1-label + #prob = torch.softmax(output, dim=1)[:, 1] + auc, eer = self._update_auc(label, prob) + ap = self._update_ap(label, prob) + + return acc, auc, eer, ap + + def _update_auc(self, lab, prob): + fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(), + prob.squeeze().cpu().numpy(), + pos_label=1) + if np.isnan(fpr[0]) or np.isnan(tpr[0]): + return -1, -1 + + auc = metrics.auc(fpr, tpr) + interp_tpr = np.interp(self.mean_fpr, fpr, tpr) + interp_tpr[0] = 0.0 + self.tprs.append(interp_tpr) + self.aucs.append(auc) + + # return auc + + # EER + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + self.eers.append(eer) + + return auc, eer + + def _update_acc(self, lab, output): + _, prediction = torch.max(output, 1) # argmax + correct = (prediction == lab).sum().item() + accuracy = correct / prediction.size(0) + # self.accs.append(accuracy) + self.correct = self.correct+correct + self.total = self.total+lab.size(0) + return accuracy + + def _update_ap(self, label, prob): + y_true = label.cpu().detach().numpy() + y_pred = prob.cpu().detach().numpy() + ap = metrics.average_precision_score(y_true,y_pred) + self.aps.append(ap) + + return np.mean(ap) + + def get_mean_metrics(self): + mean_acc, std_acc = self.correct/self.total, 0 + mean_auc, std_auc = self._mean_auc() + mean_err, std_err = np.mean(self.eers), np.std(self.eers) + mean_ap, std_ap = np.mean(self.aps), np.std(self.aps) + + return {'acc':mean_acc, 'auc':mean_auc, 'eer':mean_err, 'ap':mean_ap} + + def _mean_auc(self): + mean_tpr = np.mean(self.tprs, axis=0) + mean_tpr[-1] = 1.0 + mean_auc = metrics.auc(self.mean_fpr, mean_tpr) + std_auc = np.std(self.aucs) + return mean_auc, std_auc + + def clear(self): + self.tprs.clear() + self.aucs.clear() + # self.accs.clear() + self.correct=0 + self.total=0 + self.eers.clear() + self.aps.clear() + self.losses.clear() + + +# ------------ compute average metrics of all data --------------------- +class Metrics_all(): + def __init__(self): + self.probs = [] + self.labels = [] + self.correct = 0 + self.total = 0 + + def store(self, label, output): + prob = torch.softmax(output, dim=1)[:, 1] + _, prediction = torch.max(output, 1) # argmax + correct = (prediction == label).sum().item() + self.correct += correct + self.total += label.size(0) + self.labels.append(label.squeeze().cpu().numpy()) + self.probs.append(prob.squeeze().cpu().numpy()) + + def get_metrics(self): + y_pred = np.concatenate(self.probs) + y_true = np.concatenate(self.labels) + # auc + fpr, tpr, thresholds = metrics.roc_curve(y_true,y_pred,pos_label=1) + auc = metrics.auc(fpr, tpr) + # eer + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + # ap + ap = metrics.average_precision_score(y_true,y_pred) + # acc + acc = self.correct / self.total + return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap} + + def clear(self): + self.probs.clear() + self.labels.clear() + self.correct = 0 + self.total = 0 + + +# only used to record a series of scalar value +class Recorder: + def __init__(self): + self.sum = 0 + self.num = 0 + def update(self, item, num=1): + if item is not None: + self.sum += item * num + self.num += num + def average(self): + if self.num == 0: + return None + return self.sum/self.num + def clear(self): + self.sum = 0 + self.num = 0 diff --git a/arena/architectures/SPSL/metrics/registry.py b/arena/architectures/SPSL/metrics/registry.py new file mode 100644 index 0000000..86e256c --- /dev/null +++ b/arena/architectures/SPSL/metrics/registry.py @@ -0,0 +1,20 @@ +class Registry(object): + def __init__(self): + self.data = {} + + def register_module(self, module_name=None): + def _register(cls): + name = module_name + if module_name is None: + name = cls.__name__ + self.data[name] = cls + return cls + return _register + + def __getitem__(self, key): + return self.data[key] + +BACKBONE = Registry() +DETECTOR = Registry() +TRAINER = Registry() +LOSSFUNC = Registry() diff --git a/arena/architectures/SPSL/metrics/utils.py b/arena/architectures/SPSL/metrics/utils.py new file mode 100644 index 0000000..606d35c --- /dev/null +++ b/arena/architectures/SPSL/metrics/utils.py @@ -0,0 +1,93 @@ +from sklearn import metrics +import numpy as np + + +def parse_metric_for_print(metric_dict): + if metric_dict is None: + return "\n" + str = "\n" + str += "================================ Each dataset best metric ================================ \n" + for key, value in metric_dict.items(): + if key != 'avg': + str= str+ f"| {key}: " + for k,v in value.items(): + str = str + f" {k}={v} " + str= str+ "| \n" + else: + str += "============================================================================================= \n" + str += "================================== Average best metric ====================================== \n" + avg_dict = value + for avg_key, avg_value in avg_dict.items(): + if avg_key == 'dataset_dict': + for key,value in avg_value.items(): + str = str + f"| {key}: {value} | \n" + else: + str = str + f"| avg {avg_key}: {avg_value} | \n" + str += "=============================================================================================" + return str + + +def get_test_metrics(y_pred, y_true, img_names): + def get_video_metrics(image, pred, label): + result_dict = {} + new_label = [] + new_pred = [] + # print(image[0]) + # print(pred.shape) + # print(label.shape) + for item in np.transpose(np.stack((image, pred, label)), (1, 0)): + + s = item[0] + if '\\' in s: + parts = s.split('\\') + else: + parts = s.split('/') + a = parts[-2] + b = parts[-1] + + if a not in result_dict: + result_dict[a] = [] + + result_dict[a].append(item) + image_arr = list(result_dict.values()) + + for video in image_arr: + pred_sum = 0 + label_sum = 0 + leng = 0 + for frame in video: + pred_sum += float(frame[1]) + label_sum += int(frame[2]) + leng += 1 + new_pred.append(pred_sum / leng) + new_label.append(int(label_sum / leng)) + fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred) + v_auc = metrics.auc(fpr, tpr) + fnr = 1 - tpr + v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + return v_auc, v_eer + + + y_pred = y_pred.squeeze() + # For UCF, where labels for different manipulations are not consistent. + y_true[y_true >= 1] = 1 + # auc + fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) + auc = metrics.auc(fpr, tpr) + # eer + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + # ap + ap = metrics.average_precision_score(y_true, y_pred) + # acc + prediction_class = (y_pred > 0.5).astype(int) + correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item() + acc = correct / len(prediction_class) + if type(img_names[0]) is not list: + # calculate video-level auc for the frame-level methods. + v_auc, _ = get_video_metrics(img_names, y_pred, y_true) + else: + # video-level methods + v_auc=auc + + return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'video_auc': v_auc, 'label': y_true} diff --git a/arena/detectors/UCF/networks/__init__.py b/arena/architectures/SPSL/networks/__init__.py similarity index 100% rename from arena/detectors/UCF/networks/__init__.py rename to arena/architectures/SPSL/networks/__init__.py diff --git a/arena/detectors/UCF/networks/xception.py b/arena/architectures/SPSL/networks/xception.py similarity index 100% rename from arena/detectors/UCF/networks/xception.py rename to arena/architectures/SPSL/networks/xception.py diff --git a/arena/detectors/UCF/optimizor/LinearLR.py b/arena/architectures/SPSL/optimizor/LinearLR.py similarity index 100% rename from arena/detectors/UCF/optimizor/LinearLR.py rename to arena/architectures/SPSL/optimizor/LinearLR.py diff --git a/arena/detectors/UCF/optimizor/SAM.py b/arena/architectures/SPSL/optimizor/SAM.py similarity index 100% rename from arena/detectors/UCF/optimizor/SAM.py rename to arena/architectures/SPSL/optimizor/SAM.py diff --git a/arena/architectures/SPSL/train.py b/arena/architectures/SPSL/train.py new file mode 100644 index 0000000..b06be8a --- /dev/null +++ b/arena/architectures/SPSL/train.py @@ -0,0 +1,323 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 +# description: training code. + +import os +import argparse +from os.path import join +import cv2 +import random +import datetime +import time +import yaml +from tqdm import tqdm +import numpy as np +from datetime import timedelta +from copy import deepcopy +from PIL import Image as pil_image + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.utils.data +import torch.optim as optim +from torch.utils.data.distributed import DistributedSampler +import torch.distributed as dist + +from optimizor.SAM import SAM +from optimizor.LinearLR import LinearDecayLR + +from trainer.trainer import Trainer +from detectors import DETECTOR +from dataset import * +from metrics.utils import parse_metric_for_print +from logger import create_logger, RankFilter + + +parser = argparse.ArgumentParser(description='Process some paths.') +parser.add_argument('--detector_path', type=str, + default='/data/home/zhiyuanyan/DeepfakeBenchv2/training/config/detector/sbi.yaml', + help='path to detector YAML file') +parser.add_argument("--train_dataset", nargs="+") +parser.add_argument("--test_dataset", nargs="+") +parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True) +parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True) +parser.add_argument("--ddp", action='store_true', default=False) +parser.add_argument('--local_rank', type=int, default=0) +parser.add_argument('--task_target', type=str, default="", help='specify the target of current training task') +args = parser.parse_args() +torch.cuda.set_device(args.local_rank) + + +def init_seed(config): + if config['manualSeed'] is None: + config['manualSeed'] = random.randint(1, 10000) + random.seed(config['manualSeed']) + if config['cuda']: + torch.manual_seed(config['manualSeed']) + torch.cuda.manual_seed_all(config['manualSeed']) + + +def prepare_training_data(config): + # Only use the blending dataset class in training + if 'dataset_type' in config and config['dataset_type'] == 'blend': + if config['model_name'] == 'facexray': + train_set = FFBlendDataset(config) + elif config['model_name'] == 'fwa': + train_set = FWABlendDataset(config) + elif config['model_name'] == 'sbi': + train_set = SBIDataset(config, mode='train') + elif config['model_name'] == 'lsda': + train_set = LSDADataset(config, mode='train') + else: + raise NotImplementedError( + 'Only facexray, fwa, sbi, and lsda are currently supported for blending dataset' + ) + elif 'dataset_type' in config and config['dataset_type'] == 'pair': + train_set = pairDataset(config, mode='train') # Only use the pair dataset class in training + elif 'dataset_type' in config and config['dataset_type'] == 'iid': + train_set = IIDDataset(config, mode='train') + elif 'dataset_type' in config and config['dataset_type'] == 'I2G': + train_set = I2GDataset(config, mode='train') + elif 'dataset_type' in config and config['dataset_type'] == 'lrl': + train_set = LRLDataset(config, mode='train') + else: + train_set = DeepfakeAbstractBaseDataset( + config=config, + mode='train', + ) + if config['model_name'] == 'lsda': + from dataset.lsda_dataset import CustomSampler + custom_sampler = CustomSampler(num_groups=2*360, n_frame_per_vid=config['frame_num']['train'], batch_size=config['train_batchSize'], videos_per_group=5) + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + num_workers=int(config['workers']), + sampler=custom_sampler, + collate_fn=train_set.collate_fn, + ) + elif config['ddp']: + sampler = DistributedSampler(train_set) + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + num_workers=int(config['workers']), + collate_fn=train_set.collate_fn, + sampler=sampler + ) + else: + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + shuffle=True, + num_workers=int(config['workers']), + collate_fn=train_set.collate_fn, + ) + return train_data_loader + + +def prepare_testing_data(config): + def get_test_data_loader(config, test_name): + # update the config dictionary with the specific testing dataset + config = config.copy() # create a copy of config to avoid altering the original one + config['test_dataset'] = test_name # specify the current test dataset + if not config.get('dataset_type', None) == 'lrl': + test_set = DeepfakeAbstractBaseDataset( + config=config, + mode='test', + ) + else: + test_set = LRLDataset( + config=config, + mode='test', + ) + + test_data_loader = \ + torch.utils.data.DataLoader( + dataset=test_set, + batch_size=config['test_batchSize'], + shuffle=False, + num_workers=int(config['workers']), + collate_fn=test_set.collate_fn, + drop_last = (test_name=='DeepFakeDetection'), + ) + + return test_data_loader + + test_data_loaders = {} + for one_test_name in config['test_dataset']: + test_data_loaders[one_test_name] = get_test_data_loader(config, one_test_name) + return test_data_loaders + + +def choose_optimizer(model, config): + opt_name = config['optimizer']['type'] + if opt_name == 'sgd': + optimizer = optim.SGD( + params=model.parameters(), + lr=config['optimizer'][opt_name]['lr'], + momentum=config['optimizer'][opt_name]['momentum'], + weight_decay=config['optimizer'][opt_name]['weight_decay'] + ) + return optimizer + elif opt_name == 'adam': + optimizer = optim.Adam( + params=model.parameters(), + lr=config['optimizer'][opt_name]['lr'], + weight_decay=config['optimizer'][opt_name]['weight_decay'], + betas=(config['optimizer'][opt_name]['beta1'], config['optimizer'][opt_name]['beta2']), + eps=config['optimizer'][opt_name]['eps'], + amsgrad=config['optimizer'][opt_name]['amsgrad'], + ) + return optimizer + elif opt_name == 'sam': + optimizer = SAM( + model.parameters(), + optim.SGD, + lr=config['optimizer'][opt_name]['lr'], + momentum=config['optimizer'][opt_name]['momentum'], + ) + else: + raise NotImplementedError('Optimizer {} is not implemented'.format(config['optimizer'])) + return optimizer + + +def choose_scheduler(config, optimizer): + if config['lr_scheduler'] is None: + return None + elif config['lr_scheduler'] == 'step': + scheduler = optim.lr_scheduler.StepLR( + optimizer, + step_size=config['lr_step'], + gamma=config['lr_gamma'], + ) + return scheduler + elif config['lr_scheduler'] == 'cosine': + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=config['lr_T_max'], + eta_min=config['lr_eta_min'], + ) + return scheduler + elif config['lr_scheduler'] == 'linear': + scheduler = LinearDecayLR( + optimizer, + config['nEpochs'], + int(config['nEpochs']/4), + ) + else: + raise NotImplementedError('Scheduler {} is not implemented'.format(config['lr_scheduler'])) + + +def choose_metric(config): + metric_scoring = config['metric_scoring'] + if metric_scoring not in ['eer', 'auc', 'acc', 'ap']: + raise NotImplementedError('metric {} is not implemented'.format(metric_scoring)) + return metric_scoring + + +def main(): + # parse options and load config + with open(args.detector_path, 'r') as f: + config = yaml.safe_load(f) + with open('./training/config/train_config.yaml', 'r') as f: + config2 = yaml.safe_load(f) + if 'label_dict' in config: + config2['label_dict']=config['label_dict'] + config.update(config2) + config['local_rank']=args.local_rank + if config['dry_run']: + config['nEpochs'] = 0 + config['save_feat']=False + # If arguments are provided, they will overwrite the yaml settings + if args.train_dataset: + config['train_dataset'] = args.train_dataset + if args.test_dataset: + config['test_dataset'] = args.test_dataset + config['save_ckpt'] = args.save_ckpt + config['save_feat'] = args.save_feat + if config['lmdb']: + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + # create logger + timenow=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + task_str = f"_{config['task_target']}" if config['task_target'] is not None else "" + logger_path = os.path.join( + config['log_dir'], + config['model_name'] + task_str + '_' + timenow + ) + os.makedirs(logger_path, exist_ok=True) + logger = create_logger(os.path.join(logger_path, 'training.log')) + logger.info('Save log to {}'.format(logger_path)) + config['ddp']= args.ddp + # print configuration + logger.info("--------------- Configuration ---------------") + params_string = "Parameters: \n" + for key, value in config.items(): + params_string += "{}: {}".format(key, value) + "\n" + logger.info(params_string) + + # init seed + init_seed(config) + + # set cudnn benchmark if needed + if config['cudnn']: + cudnn.benchmark = True + if config['ddp']: + # dist.init_process_group(backend='gloo') + dist.init_process_group( + backend='nccl', + timeout=timedelta(minutes=30) + ) + logger.addFilter(RankFilter(0)) + # prepare the training data loader + train_data_loader = prepare_training_data(config) + + # prepare the testing data loader + test_data_loaders = prepare_testing_data(config) + + # prepare the model (detector) + model_class = DETECTOR[config['model_name']] + model = model_class(config) + + # prepare the optimizer + optimizer = choose_optimizer(model, config) + + # prepare the scheduler + scheduler = choose_scheduler(config, optimizer) + + # prepare the metric + metric_scoring = choose_metric(config) + + # prepare the trainer + trainer = Trainer(config, model, optimizer, scheduler, logger, metric_scoring) + + # start training + for epoch in range(config['start_epoch'], config['nEpochs'] + 1): + trainer.model.epoch = epoch + best_metric = trainer.train_epoch( + epoch=epoch, + train_data_loader=train_data_loader, + test_data_loaders=test_data_loaders, + ) + if best_metric is not None: + logger.info(f"===> Epoch[{epoch}] end with testing {metric_scoring}: {parse_metric_for_print(best_metric)}!") + logger.info("Stop Training on best Testing metric {}".format(parse_metric_for_print(best_metric))) + # update + if 'svdd' in config['model_name']: + model.update_R(epoch) + if scheduler is not None: + scheduler.step() + + # close the tensorboard writers + for writer in trainer.writers.values(): + writer.close() + + + +if __name__ == '__main__': + main() diff --git a/arena/architectures/SPSL/trainer/__init__.py b/arena/architectures/SPSL/trainer/__init__.py new file mode 100644 index 0000000..5326782 --- /dev/null +++ b/arena/architectures/SPSL/trainer/__init__.py @@ -0,0 +1,9 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +from metrics.registry import TRAINER \ No newline at end of file diff --git a/arena/architectures/SPSL/trainer/base_trainer.py b/arena/architectures/SPSL/trainer/base_trainer.py new file mode 100644 index 0000000..e140275 --- /dev/null +++ b/arena/architectures/SPSL/trainer/base_trainer.py @@ -0,0 +1,50 @@ +import datetime +from copy import deepcopy +from abc import ABC, abstractmethod + + +class BaseTrainer(ABC): + """ + """ + + def __init__( + self, + config, + model, + optimizer, + scheduler, + writer, + ): + # check if all the necessary components are implemented + if config is None or model is None or optimizer is None or scheduler is None or writer is None: + raise NotImplementedError("config, model, optimizier, scheduler, and tensorboard writer must be implemented") + + self.config = config + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.writer = writer + + @abstractmethod + def speed_up(self): + pass + + @abstractmethod + def setTrain(self): + pass + + @abstractmethod + def setEval(self): + pass + + @abstractmethod + def load_ckpt(self, model_path): + pass + + @abstractmethod + def save_ckpt(self, dataset, epoch, iters, best=False): + pass + + @abstractmethod + def inference(self, data_dict): + pass diff --git a/arena/architectures/SPSL/trainer/trainer.py b/arena/architectures/SPSL/trainer/trainer.py new file mode 100644 index 0000000..471cddc --- /dev/null +++ b/arena/architectures/SPSL/trainer/trainer.py @@ -0,0 +1,463 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 +# description: trainer +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +import pickle +import datetime +import logging +import numpy as np +from copy import deepcopy +from collections import defaultdict +from tqdm import tqdm +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter +from metrics.base_metrics_class import Recorder +from torch.optim.swa_utils import AveragedModel, SWALR +from torch import distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from sklearn import metrics +from metrics.utils import get_test_metrics + +FFpp_pool=['FaceForensics++','FF-DF','FF-F2F','FF-FS','FF-NT']# +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Trainer(object): + def __init__( + self, + config, + model, + optimizer, + scheduler, + logger, + metric_scoring='auc', + time_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), + swa_model=None + ): + # check if all the necessary components are implemented + if config is None or model is None or optimizer is None or logger is None: + raise ValueError("config, model, optimizier, logger, and tensorboard writer must be implemented") + + self.config = config + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.swa_model = swa_model + self.writers = {} # dict to maintain different tensorboard writers for each dataset and metric + self.logger = logger + self.metric_scoring = metric_scoring + # maintain the best metric of all epochs + self.best_metrics_all_time = defaultdict( + lambda: defaultdict(lambda: float('-inf') + if self.metric_scoring != 'eer' else float('inf')) + ) + self.speed_up() # move model to GPU + + # get current time + self.timenow = time_now + # create directory path + if 'task_target' not in config: + self.log_dir = os.path.join( + self.config['log_dir'], + self.config['model_name'] + '_' + self.timenow + ) + else: + task_str = f"_{config['task_target']}" if config['task_target'] is not None else "" + self.log_dir = os.path.join( + self.config['log_dir'], + self.config['model_name'] + task_str + '_' + self.timenow + ) + os.makedirs(self.log_dir, exist_ok=True) + + def get_writer(self, phase, dataset_key, metric_key): + writer_key = f"{phase}-{dataset_key}-{metric_key}" + if writer_key not in self.writers: + # update directory path + writer_path = os.path.join( + self.log_dir, + phase, + dataset_key, + metric_key, + "metric_board" + ) + os.makedirs(writer_path, exist_ok=True) + # update writers dictionary + self.writers[writer_key] = SummaryWriter(writer_path) + return self.writers[writer_key] + + + def speed_up(self): + self.model.to(device) + self.model.device = device + if self.config['ddp'] == True: + num_gpus = torch.cuda.device_count() + print(f'avai gpus: {num_gpus}') + # local_rank=[i for i in range(0,num_gpus)] + self.model = DDP(self.model, device_ids=[self.config['local_rank']],find_unused_parameters=True, output_device=self.config['local_rank']) + #self.optimizer = nn.DataParallel(self.optimizer, device_ids=[int(os.environ['LOCAL_RANK'])]) + + def setTrain(self): + self.model.train() + self.train = True + + def setEval(self): + self.model.eval() + self.train = False + + def load_ckpt(self, model_path): + if os.path.isfile(model_path): + saved = torch.load(model_path, map_location='cpu') + suffix = model_path.split('.')[-1] + if suffix == 'p': + self.model.load_state_dict(saved.state_dict()) + else: + self.model.load_state_dict(saved) + self.logger.info('Model found in {}'.format(model_path)) + else: + raise NotImplementedError( + "=> no model found at '{}'".format(model_path)) + + def save_ckpt(self, phase, dataset_key,ckpt_info=None): + save_dir = os.path.join(self.log_dir, phase, dataset_key) + os.makedirs(save_dir, exist_ok=True) + ckpt_name = f"ckpt_best.pth" + save_path = os.path.join(save_dir, ckpt_name) + if self.config['ddp'] == True: + torch.save(self.model.state_dict(), save_path) + else: + if 'svdd' in self.config['model_name']: + torch.save({'R': self.model.R, + 'c': self.model.c, + 'state_dict': self.model.state_dict(),}, save_path) + else: + torch.save(self.model.state_dict(), save_path) + self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}") + + def save_swa_ckpt(self): + save_dir = self.log_dir + os.makedirs(save_dir, exist_ok=True) + ckpt_name = f"swa.pth" + save_path = os.path.join(save_dir, ckpt_name) + torch.save(self.swa_model.state_dict(), save_path) + self.logger.info(f"SWA Checkpoint saved to {save_path}") + + + def save_feat(self, phase, fea, dataset_key): + save_dir = os.path.join(self.log_dir, phase, dataset_key) + os.makedirs(save_dir, exist_ok=True) + features = fea + feat_name = f"feat_best.npy" + save_path = os.path.join(save_dir, feat_name) + np.save(save_path, features) + self.logger.info(f"Feature saved to {save_path}") + + def save_data_dict(self, phase, data_dict, dataset_key): + save_dir = os.path.join(self.log_dir, phase, dataset_key) + os.makedirs(save_dir, exist_ok=True) + file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle') + with open(file_path, 'wb') as file: + pickle.dump(data_dict, file) + self.logger.info(f"data_dict saved to {file_path}") + + def save_metrics(self, phase, metric_one_dataset, dataset_key): + save_dir = os.path.join(self.log_dir, phase, dataset_key) + os.makedirs(save_dir, exist_ok=True) + file_path = os.path.join(save_dir, 'metric_dict_best.pickle') + with open(file_path, 'wb') as file: + pickle.dump(metric_one_dataset, file) + self.logger.info(f"Metrics saved to {file_path}") + + def train_step(self,data_dict): + if self.config['optimizer']['type']=='sam': + for i in range(2): + predictions = self.model(data_dict) + losses = self.model.get_losses(data_dict, predictions) + if i == 0: + pred_first = predictions + losses_first = losses + self.optimizer.zero_grad() + losses['overall'].backward() + if i == 0: + self.optimizer.first_step(zero_grad=True) + else: + self.optimizer.second_step(zero_grad=True) + return losses_first, pred_first + else: + + predictions = self.model(data_dict) + if type(self.model) is DDP: + losses = self.model.module.get_losses(data_dict, predictions) + else: + losses = self.model.get_losses(data_dict, predictions) + self.optimizer.zero_grad() + losses['overall'].backward() + self.optimizer.step() + + + return losses,predictions + + + def train_epoch( + self, + epoch, + train_data_loader, + test_data_loaders=None, + ): + + self.logger.info("===> Epoch[{}] start!".format(epoch)) + if epoch>=1: + times_per_epoch = 2 + else: + times_per_epoch = 1 + + + #times_per_epoch=4 + + test_step = len(train_data_loader) // times_per_epoch # test 10 times per epoch + step_cnt = epoch * len(train_data_loader) + + # save the training data_dict + data_dict = train_data_loader.dataset.data_dict + self.save_data_dict('train', data_dict, ','.join(self.config['train_dataset'])) + # define training recorder + train_recorder_loss = defaultdict(Recorder) + train_recorder_metric = defaultdict(Recorder) + + for iteration, data_dict in tqdm(enumerate(train_data_loader),total=len(train_data_loader)): + self.setTrain() + # more elegant and more scalable way of moving data to GPU + for key in data_dict.keys(): + if data_dict[key]!=None and key!='name': + data_dict[key]=data_dict[key].cuda() + + losses,predictions=self.train_step(data_dict) + + # update learning rate + + if 'SWA' in self.config and self.config['SWA'] and epoch>self.config['swa_start']: + self.swa_model.update_parameters(self.model) + + # compute training metric for each batch data + if type(self.model) is DDP: + batch_metrics = self.model.module.get_train_metrics(data_dict, predictions) + else: + batch_metrics = self.model.get_train_metrics(data_dict, predictions) + + # store data by recorder + ## store metric + for name, value in batch_metrics.items(): + train_recorder_metric[name].update(value) + ## store loss + for name, value in losses.items(): + train_recorder_loss[name].update(value) + + # run tensorboard to visualize the training process + if iteration % 300 == 0 and self.config['local_rank']==0: + if self.config['SWA'] and (epoch>self.config['swa_start'] or self.config['dry_run']): + self.scheduler.step() + # info for loss + loss_str = f"Iter: {step_cnt} " + for k, v in train_recorder_loss.items(): + v_avg = v.average() + if v_avg == None: + loss_str += f"training-loss, {k}: not calculated" + continue + loss_str += f"training-loss, {k}: {v_avg} " + # tensorboard-1. loss + writer = self.get_writer('train', ','.join(self.config['train_dataset']), k) + writer.add_scalar(f'train_loss/{k}', v_avg, global_step=step_cnt) + self.logger.info(loss_str) + # info for metric + metric_str = f"Iter: {step_cnt} " + for k, v in train_recorder_metric.items(): + v_avg = v.average() + if v_avg == None: + metric_str += f"training-metric, {k}: not calculated " + continue + metric_str += f"training-metric, {k}: {v_avg} " + # tensorboard-2. metric + writer = self.get_writer('train', ','.join(self.config['train_dataset']), k) + writer.add_scalar(f'train_metric/{k}', v_avg, global_step=step_cnt) + self.logger.info(metric_str) + + + + # clear recorder. + # Note we only consider the current 300 samples for computing batch-level loss/metric + for name, recorder in train_recorder_loss.items(): # clear loss recorder + recorder.clear() + for name, recorder in train_recorder_metric.items(): # clear metric recorder + recorder.clear() + + # run test + if (step_cnt+1) % test_step == 0: + if test_data_loaders is not None and (not self.config['ddp'] ): + self.logger.info("===> Test start!") + test_best_metric = self.test_epoch( + epoch, + iteration, + test_data_loaders, + step_cnt, + ) + elif test_data_loaders is not None and (self.config['ddp'] and dist.get_rank() == 0): + self.logger.info("===> Test start!") + test_best_metric = self.test_epoch( + epoch, + iteration, + test_data_loaders, + step_cnt, + ) + else: + test_best_metric = None + + # total_end_time = time.time() + # total_elapsed_time = total_end_time - total_start_time + # print("总花费的时间: {:.2f} 秒".format(total_elapsed_time)) + step_cnt += 1 + return test_best_metric + + def get_respect_acc(self,prob,label): + pred = np.where(prob > 0.5, 1, 0) + judge = (pred == label) + zero_num = len(label) - np.count_nonzero(label) + acc_fake = np.count_nonzero(judge[zero_num:]) / len(judge[zero_num:]) + acc_real = np.count_nonzero(judge[:zero_num]) / len(judge[:zero_num]) + return acc_real,acc_fake + + def test_one_dataset(self, data_loader): + # define test recorder + test_recorder_loss = defaultdict(Recorder) + prediction_lists = [] + feature_lists=[] + label_lists = [] + for i, data_dict in tqdm(enumerate(data_loader),total=len(data_loader)): + # get data + if 'label_spe' in data_dict: + data_dict.pop('label_spe') # remove the specific label + data_dict['label'] = torch.where(data_dict['label']!=0, 1, 0) # fix the label to 0 and 1 only + # move data to GPU elegantly + for key in data_dict.keys(): + if data_dict[key]!=None: + data_dict[key]=data_dict[key].cuda() + # model forward without considering gradient computation + predictions = self.inference(data_dict) + label_lists += list(data_dict['label'].cpu().detach().numpy()) + prediction_lists += list(predictions['prob'].cpu().detach().numpy()) + feature_lists += list(predictions['feat'].cpu().detach().numpy()) + if type(self.model) is not AveragedModel: + # compute all losses for each batch data + if type(self.model) is DDP: + losses = self.model.module.get_losses(data_dict, predictions) + else: + losses = self.model.get_losses(data_dict, predictions) + + # store data by recorder + for name, value in losses.items(): + test_recorder_loss[name].update(value) + + return test_recorder_loss, np.array(prediction_lists), np.array(label_lists),np.array(feature_lists) + + def save_best(self,epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset): + best_metric = self.best_metrics_all_time[key].get(self.metric_scoring, + float('-inf') if self.metric_scoring != 'eer' else float( + 'inf')) + # Check if the current score is an improvement + improved = (metric_one_dataset[self.metric_scoring] > best_metric) if self.metric_scoring != 'eer' else ( + metric_one_dataset[self.metric_scoring] < best_metric) + if improved: + # Update the best metric + self.best_metrics_all_time[key][self.metric_scoring] = metric_one_dataset[self.metric_scoring] + if key == 'avg': + self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict'] + # Save checkpoint, feature, and metrics if specified in config + if self.config['save_ckpt'] and key not in FFpp_pool: + self.save_ckpt('test', key, f"{epoch}+{iteration}") + self.save_metrics('test', metric_one_dataset, key) + if losses_one_dataset_recorder is not None: + # info for each dataset + loss_str = f"dataset: {key} step: {step} " + for k, v in losses_one_dataset_recorder.items(): + writer = self.get_writer('test', key, k) + v_avg = v.average() + if v_avg == None: + print(f'{k} is not calculated') + continue + # tensorboard-1. loss + writer.add_scalar(f'test_losses/{k}', v_avg, global_step=step) + loss_str += f"testing-loss, {k}: {v_avg} " + self.logger.info(loss_str) + # tqdm.write(loss_str) + metric_str = f"dataset: {key} step: {step} " + for k, v in metric_one_dataset.items(): + if k == 'pred' or k == 'label' or k=='dataset_dict': + continue + metric_str += f"testing-metric, {k}: {v} " + # tensorboard-2. metric + writer = self.get_writer('test', key, k) + writer.add_scalar(f'test_metrics/{k}', v, global_step=step) + if 'pred' in metric_one_dataset: + acc_real, acc_fake = self.get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label']) + metric_str += f'testing-metric, acc_real:{acc_real}; acc_fake:{acc_fake}' + writer.add_scalar(f'test_metrics/acc_real', acc_real, global_step=step) + writer.add_scalar(f'test_metrics/acc_fake', acc_fake, global_step=step) + self.logger.info(metric_str) + def test_epoch(self, epoch, iteration, test_data_loaders, step): + # set model to eval mode + self.setEval() + + # define test recorder + losses_all_datasets = {} + metrics_all_datasets = {} + best_metrics_per_dataset = defaultdict(dict) # best metric for each dataset, for each metric + avg_metric = {'acc': 0, 'auc': 0, 'eer': 0, 'ap': 0,'video_auc': 0,'dataset_dict':{}} + # testing for all test data + keys = test_data_loaders.keys() + for key in keys: + # save the testing data_dict + data_dict = test_data_loaders[key].dataset.data_dict + self.save_data_dict('test', data_dict, key) + + # compute loss for each dataset + losses_one_dataset_recorder, predictions_nps, label_nps, feature_nps = self.test_one_dataset(test_data_loaders[key]) + # print(f'stack len:{predictions_nps.shape};{label_nps.shape};{len(data_dict["image"])}') + losses_all_datasets[key] = losses_one_dataset_recorder + metric_one_dataset=get_test_metrics(y_pred=predictions_nps,y_true=label_nps,img_names=data_dict['image']) + for metric_name, value in metric_one_dataset.items(): + if metric_name in avg_metric: + avg_metric[metric_name]+=value + avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring] + if type(self.model) is AveragedModel: + metric_str = f"Iter Final for SWA: " + for k, v in metric_one_dataset.items(): + metric_str += f"testing-metric, {k}: {v} " + self.logger.info(metric_str) + continue + self.save_best(epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset) + + if len(keys)>0 and self.config.get('save_avg',False): + # calculate avg value + for key in avg_metric: + if key != 'dataset_dict': + avg_metric[key] /= len(keys) + self.save_best(epoch, iteration, step, None, 'avg', avg_metric) + + self.logger.info('===> Test Done!') + return self.best_metrics_all_time # return all types of mean metrics for determining the best ckpt + + @torch.no_grad() + def inference(self, data_dict): + predictions = self.model(data_dict, inference=True) + return predictions diff --git a/arena/detectors/UCF/README.md b/arena/architectures/UCF/README.md similarity index 100% rename from arena/detectors/UCF/README.md rename to arena/architectures/UCF/README.md diff --git a/arena/detectors/UCF/metrics/__init__.py b/arena/architectures/UCF/config/__init__.py similarity index 100% rename from arena/detectors/UCF/metrics/__init__.py rename to arena/architectures/UCF/config/__init__.py diff --git a/arena/detectors/UCF/config/constants.py b/arena/architectures/UCF/config/constants.py similarity index 100% rename from arena/detectors/UCF/config/constants.py rename to arena/architectures/UCF/config/constants.py diff --git a/arena/detectors/UCF/config/pretrained_config.yaml b/arena/architectures/UCF/config/pretrained_config.yaml similarity index 100% rename from arena/detectors/UCF/config/pretrained_config.yaml rename to arena/architectures/UCF/config/pretrained_config.yaml diff --git a/arena/detectors/UCF/config/pretrained_face_config.yaml b/arena/architectures/UCF/config/pretrained_face_config.yaml similarity index 100% rename from arena/detectors/UCF/config/pretrained_face_config.yaml rename to arena/architectures/UCF/config/pretrained_face_config.yaml diff --git a/arena/detectors/UCF/config/train_config.yaml b/arena/architectures/UCF/config/train_config.yaml similarity index 100% rename from arena/detectors/UCF/config/train_config.yaml rename to arena/architectures/UCF/config/train_config.yaml diff --git a/arena/detectors/UCF/config/ucf.yaml b/arena/architectures/UCF/config/ucf.yaml similarity index 100% rename from arena/detectors/UCF/config/ucf.yaml rename to arena/architectures/UCF/config/ucf.yaml diff --git a/arena/architectures/UCF/config/xception.yaml b/arena/architectures/UCF/config/xception.yaml new file mode 100644 index 0000000..9198f69 --- /dev/null +++ b/arena/architectures/UCF/config/xception.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs/testing_bench + +# model setting +pretrained: /data/home/zhiyuanyan/DeepfakeBench/training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: xception # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [FaceForensics++, DeepFakeDetection] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.0 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/arena/detectors/UCF/detectors/__init__.py b/arena/architectures/UCF/detectors/__init__.py similarity index 100% rename from arena/detectors/UCF/detectors/__init__.py rename to arena/architectures/UCF/detectors/__init__.py diff --git a/arena/detectors/UCF/detectors/base_detector.py b/arena/architectures/UCF/detectors/base_detector.py similarity index 100% rename from arena/detectors/UCF/detectors/base_detector.py rename to arena/architectures/UCF/detectors/base_detector.py diff --git a/arena/detectors/UCF/detectors/ucf_detector.py b/arena/architectures/UCF/detectors/ucf_detector.py similarity index 99% rename from arena/detectors/UCF/detectors/ucf_detector.py rename to arena/architectures/UCF/detectors/ucf_detector.py index 21487ab..7f22043 100644 --- a/arena/detectors/UCF/detectors/ucf_detector.py +++ b/arena/architectures/UCF/detectors/ucf_detector.py @@ -44,7 +44,7 @@ from metrics.base_metrics_class import calculate_metrics_for_train from .base_detector import AbstractDetector -from arena.detectors.UCF.detectors import DETECTOR +from arena.architectures.UCF.detectors import DETECTOR from networks import BACKBONE from loss import LOSSFUNC diff --git a/arena/architectures/UCF/logger.py b/arena/architectures/UCF/logger.py new file mode 100644 index 0000000..9ee268d --- /dev/null +++ b/arena/architectures/UCF/logger.py @@ -0,0 +1,36 @@ +import os +import logging + +import torch.distributed as dist + +class RankFilter(logging.Filter): + def __init__(self, rank): + super().__init__() + self.rank = rank + + def filter(self, record): + return dist.get_rank() == self.rank + +def create_logger(log_path): + # Create log path + if os.path.isdir(os.path.dirname(log_path)): + os.makedirs(os.path.dirname(log_path), exist_ok=True) + + # Create logger object + logger = logging.getLogger() + logger.setLevel(logging.INFO) + # Create file handler and set the formatter + fh = logging.FileHandler(log_path) + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + fh.setFormatter(formatter) + + # Add the file handler to the logger + logger.addHandler(fh) + + # Add a stream handler to print to console + sh = logging.StreamHandler() + sh.setLevel(logging.INFO) # Set logging level for stream handler + sh.setFormatter(formatter) + logger.addHandler(sh) + + return logger \ No newline at end of file diff --git a/arena/detectors/UCF/loss/__init__.py b/arena/architectures/UCF/loss/__init__.py similarity index 100% rename from arena/detectors/UCF/loss/__init__.py rename to arena/architectures/UCF/loss/__init__.py diff --git a/arena/architectures/UCF/loss/abstract_loss_func.py b/arena/architectures/UCF/loss/abstract_loss_func.py new file mode 100644 index 0000000..45d3324 --- /dev/null +++ b/arena/architectures/UCF/loss/abstract_loss_func.py @@ -0,0 +1,17 @@ +import torch.nn as nn + +class AbstractLossClass(nn.Module): + """Abstract class for loss functions.""" + def __init__(self): + super(AbstractLossClass, self).__init__() + + def forward(self, pred, label): + """ + Args: + pred: prediction of the model + label: ground truth label + + Return: + loss: loss value + """ + raise NotImplementedError('Each subclass should implement the forward method.') diff --git a/arena/detectors/UCF/loss/contrastive_regularization.py b/arena/architectures/UCF/loss/contrastive_regularization.py similarity index 100% rename from arena/detectors/UCF/loss/contrastive_regularization.py rename to arena/architectures/UCF/loss/contrastive_regularization.py diff --git a/arena/architectures/UCF/loss/cross_entropy_loss.py b/arena/architectures/UCF/loss/cross_entropy_loss.py new file mode 100644 index 0000000..efa7123 --- /dev/null +++ b/arena/architectures/UCF/loss/cross_entropy_loss.py @@ -0,0 +1,26 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="cross_entropy") +class CrossEntropyLoss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, inputs, targets): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss = self.loss_fn(inputs, targets) + + return loss \ No newline at end of file diff --git a/arena/detectors/UCF/loss/l1_loss.py b/arena/architectures/UCF/loss/l1_loss.py similarity index 100% rename from arena/detectors/UCF/loss/l1_loss.py rename to arena/architectures/UCF/loss/l1_loss.py diff --git a/arena/architectures/UCF/metrics/__init__.py b/arena/architectures/UCF/metrics/__init__.py new file mode 100644 index 0000000..1ed99d5 --- /dev/null +++ b/arena/architectures/UCF/metrics/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) \ No newline at end of file diff --git a/arena/detectors/UCF/metrics/base_metrics_class.py b/arena/architectures/UCF/metrics/base_metrics_class.py similarity index 100% rename from arena/detectors/UCF/metrics/base_metrics_class.py rename to arena/architectures/UCF/metrics/base_metrics_class.py diff --git a/arena/detectors/UCF/metrics/registry.py b/arena/architectures/UCF/metrics/registry.py similarity index 100% rename from arena/detectors/UCF/metrics/registry.py rename to arena/architectures/UCF/metrics/registry.py diff --git a/arena/detectors/UCF/metrics/utils.py b/arena/architectures/UCF/metrics/utils.py similarity index 100% rename from arena/detectors/UCF/metrics/utils.py rename to arena/architectures/UCF/metrics/utils.py diff --git a/arena/architectures/UCF/networks/__init__.py b/arena/architectures/UCF/networks/__init__.py new file mode 100644 index 0000000..f9be255 --- /dev/null +++ b/arena/architectures/UCF/networks/__init__.py @@ -0,0 +1,11 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +from metrics.registry import BACKBONE + +from .xception import Xception \ No newline at end of file diff --git a/arena/architectures/UCF/networks/xception.py b/arena/architectures/UCF/networks/xception.py new file mode 100644 index 0000000..410345c --- /dev/null +++ b/arena/architectures/UCF/networks/xception.py @@ -0,0 +1,285 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 + +The code is mainly modified from GitHub link below: +https://github.com/ondyari/FaceForensics/blob/master/classification/network/xception.py +''' + +import os +import argparse +import logging + +import math +import torch +# import pretrainedmodels +import torch.nn as nn +import torch.nn.functional as F + +import torch.utils.model_zoo as model_zoo +from torch.nn import init +from typing import Union +from metrics.registry import BACKBONE + +logger = logging.getLogger(__name__) + + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, + stride, padding, dilation, groups=in_channels, bias=bias) + self.pointwise = nn.Conv2d( + in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class Block(nn.Module): + def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): + super(Block, self).__init__() + + if out_filters != in_filters or strides != 1: + self.skip = nn.Conv2d(in_filters, out_filters, + 1, stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_filters) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=True) + rep = [] + + filters = in_filters + if grow_first: # whether the number of filters grows first + rep.append(self.relu) + rep.append(SeparableConv2d(in_filters, out_filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + filters = out_filters + + for i in range(reps-1): + rep.append(self.relu) + rep.append(SeparableConv2d(filters, filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d(in_filters, out_filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + +def add_gaussian_noise(ins, mean=0, stddev=0.2): + noise = ins.data.new(ins.size()).normal_(mean, stddev) + return ins + noise + + +@BACKBONE.register_module(module_name="xception") +class Xception(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, xception_config): + """ Constructor + Args: + xception_config: configuration file with the dict format + """ + super(Xception, self).__init__() + self.num_classes = xception_config["num_classes"] + self.mode = xception_config["mode"] + inc = xception_config["inc"] + dropout = xception_config["dropout"] + + # Entry flow + self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False) + + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + # do relu here + + self.block1 = Block( + 64, 128, 2, 2, start_with_relu=False, grow_first=True) + self.block2 = Block( + 128, 256, 2, 2, start_with_relu=True, grow_first=True) + self.block3 = Block( + 256, 728, 2, 2, start_with_relu=True, grow_first=True) + + # middle flow + self.block4 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block5 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block6 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block7 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + + self.block8 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block9 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block10 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block11 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + + # Exit flow + self.block12 = Block( + 728, 1024, 2, 2, start_with_relu=True, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + + # do relu here + self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(2048) + # used for iid + final_channel = 2048 + if self.mode == 'adjust_channel_iid': + final_channel = 512 + self.mode = 'adjust_channel' + self.last_linear = nn.Linear(final_channel, self.num_classes) + if dropout: + self.last_linear = nn.Sequential( + nn.Dropout(p=dropout), + nn.Linear(final_channel, self.num_classes) + ) + + self.adjust_channel = nn.Sequential( + nn.Conv2d(2048, 512, 1, 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace=False), + ) + + def fea_part1_0(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + return x + + def fea_part1_1(self, x): + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + def fea_part1(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + def fea_part2(self, x): + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + + return x + + def fea_part3(self, x): + if self.mode == "shallow_xception": + return x + else: + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + return x + + def fea_part4(self, x): + if self.mode == "shallow_xception": + x = self.block12(x) + else: + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + return x + + def fea_part5(self, x): + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.conv4(x) + x = self.bn4(x) + + return x + + def features(self, input): + x = self.fea_part1(input) + + x = self.fea_part2(x) + x = self.fea_part3(x) + x = self.fea_part4(x) + + x = self.fea_part5(x) + + if self.mode == 'adjust_channel': + x = self.adjust_channel(x) + + return x + + def classifier(self, features,id_feat=None): + # for iid + if self.mode == 'adjust_channel': + x = features + else: + x = self.relu(features) + + if len(x.shape) == 4: + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + self.last_emb = x + # for iid + if id_feat!=None: + out = self.last_linear(x-id_feat) + else: + out = self.last_linear(x) + return out + + def forward(self, input): + x = self.features(input) + out = self.classifier(x) + return out, x diff --git a/arena/architectures/UCF/optimizor/LinearLR.py b/arena/architectures/UCF/optimizor/LinearLR.py new file mode 100644 index 0000000..80bc70d --- /dev/null +++ b/arena/architectures/UCF/optimizor/LinearLR.py @@ -0,0 +1,20 @@ +import torch +from torch.optim import SGD +from torch.optim.lr_scheduler import _LRScheduler + +class LinearDecayLR(_LRScheduler): + def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1): + self.start_decay=start_decay + self.n_epoch=n_epoch + super(LinearDecayLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + last_epoch = self.last_epoch + n_epoch=self.n_epoch + b_lr=self.base_lrs[0] + start_decay=self.start_decay + if last_epoch>start_decay: + lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay) + else: + lr=b_lr + return [lr] \ No newline at end of file diff --git a/arena/architectures/UCF/optimizor/SAM.py b/arena/architectures/UCF/optimizor/SAM.py new file mode 100644 index 0000000..7b8d1dc --- /dev/null +++ b/arena/architectures/UCF/optimizor/SAM.py @@ -0,0 +1,77 @@ +# borrowed from + +import torch + +import torch +import torch.nn as nn + +def disable_running_stats(model): + def _disable(module): + if isinstance(module, nn.BatchNorm2d): + module.backup_momentum = module.momentum + module.momentum = 0 + + model.apply(_disable) + +def enable_running_stats(model): + def _enable(module): + if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"): + module.momentum = module.backup_momentum + + model.apply(_enable) + +class SAM(torch.optim.Optimizer): + def __init__(self, params, base_optimizer, rho=0.05, **kwargs): + assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" + + defaults = dict(rho=rho, **kwargs) + super(SAM, self).__init__(params, defaults) + + self.base_optimizer = base_optimizer(self.param_groups, **kwargs) + self.param_groups = self.base_optimizer.param_groups + + @torch.no_grad() + def first_step(self, zero_grad=False): + grad_norm = self._grad_norm() + for group in self.param_groups: + scale = group["rho"] / (grad_norm + 1e-12) + + for p in group["params"]: + if p.grad is None: continue + e_w = p.grad * scale.to(p) + p.add_(e_w) # climb to the local maximum "w + e(w)" + self.state[p]["e_w"] = e_w + + if zero_grad: self.zero_grad() + + @torch.no_grad() + def second_step(self, zero_grad=False): + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: continue + p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)" + + self.base_optimizer.step() # do the actual "sharpness-aware" update + + if zero_grad: self.zero_grad() + + @torch.no_grad() + def step(self, closure=None): + assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" + closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass + + self.first_step(zero_grad=True) + closure() + self.second_step() + + def _grad_norm(self): + shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism + norm = torch.norm( + torch.stack([ + p.grad.norm(p=2).to(shared_device) + for group in self.param_groups for p in group["params"] + if p.grad is not None + ]), + p=2 + ) + return norm \ No newline at end of file diff --git a/arena/detectors/UCF/train_detector.py b/arena/architectures/UCF/train_detector.py similarity index 100% rename from arena/detectors/UCF/train_detector.py rename to arena/architectures/UCF/train_detector.py diff --git a/arena/detectors/UCF/trainer/trainer.py b/arena/architectures/UCF/trainer/trainer.py similarity index 100% rename from arena/detectors/UCF/trainer/trainer.py rename to arena/architectures/UCF/trainer/trainer.py diff --git a/arena/detectors/README.md b/arena/detectors/README.md index c6b4c2a..368a3f3 100644 --- a/arena/detectors/README.md +++ b/arena/detectors/README.md @@ -1,39 +1,46 @@ -## Base Miners +## Base Detectors -The `base_miner/` directory facilitates the training, orchestration, and deployment of modular and highly customizable deepfake detectors. +The `detectors/` directory facilitates the training, orchestration, and deployment of modular and highly customizable deepfake detectors. We broadly define **detector** as an algorithm that either employs a single model or orchestrates multiple models to perform the binary real-or-AI inference task. These **models** can be any algorithm that processes an image to determine its classification. This includes not only pretrained machine learning architectures, but also heuristic and statistical modeling frameworks. -## Our Base Miner Detector: Content-Aware Model Orchestration (CAMO) +## Adding Your Own Detector -Read about [CAMO (Content Aware Model Orchestration)](https://bitmindlabs.notion.site/CAMO-Content-Aware-Model-Orchestration-CAMO-Framework-for-Deepfake-Detection-43ef46a0f9de403abec7a577a45cd075), our generalized framework for creating “hard mixture of expert” detectors. +If you're interested in creating and adding your own detector to the DFD Arena framework, please refer to our [Tutorial on Adding a New Deepfake Detector](tutorial.md). This guide provides step-by-step instructions on how to integrate your custom detector into our system. -- **Latest Iteration**: The most performant iteration of `class CAMODetector(DeepfakeDetector)` used in our base miner `neurons/miner.py` incorporates a `GatingMechanism(Gate)` that routes to a fine-tuned face expert model and generalist model with the `UCF` architecture. +## Our Base Detector: Content-Aware Model Orchestration (CAMO) -## Directory Structure +Read about [CAMO (Content Aware Model Orchestration)](https://bitmindlabs.notion.site/CAMO-Content-Aware-Model-Orchestration-CAMO-Framework-for-Deepfake-Detection-43ef46a0f9de403abec7a577a45cd075), our generalized framework for creating "hard mixture of expert" detectors. -### 1. Architectures and Training -- **UCF/** and **NPR/** +- **Latest Iteration**: The most performant iteration of `class CAMODetector(DeepfakeDetector)` used in our base miner `neurons/miner.py` incorporates a `GatingMechanism(Gate)` that routes to a fine-tuned face expert model and generalist model with the `UCF` architecture. -These folders contain model architectures and training loops for `UCF (ICCV 2023)` and `NPR (CVPR 2024)`, adapted to use curated and preprocessed training datasets on our [BitMind Huggingface](https://huggingface.co/bitmind). +## Directory Structure -### 2. deepfake_detectors/ +### 1. detectors/ The modular structure for detectors used in the miner neuron is defined here, through `DeepfakeDetector` abstract base class and subclass implementations. -- **deepfake_detectors/** contains: +- **detectors/** contains: - **configs/**: YAML configuration files to load detector instance attributes, including any pretrained model weights. - **Abstract Base Class**: A foundational class that outlines the standard structure for implementing detectors. - **Detector Subclasses**: Specialized detector implementations that can be dynamically loaded and managed based on configuration. -The `DeepfakeDetector design` allows for high configurability and extension. +The `DeepfakeDetector` design allows for high configurability and extension. + +### 2. architectures/ +This folder contains model-specific files for various architectures used in deepfake detection: + +- **UCF/**: Contains model architecture and training loops for `UCF (ICCV 2023)`. +- **NPR/**: Contains model architecture and training loops for `NPR (CVPR 2024)`. + +Both are adapted to use curated and preprocessed training datasets on our [BitMind Huggingface](https://huggingface.co/bitmind). -### 3. gating_mechanisms/ -Similar to `deepfake_detectors/`, this folder contains abstract base classes and implementations of `Gate`s that are used to handle content-aware preprocessing and routing. This is especially useful for multi-agent detection systems, such as the `DeepfakeDetector` subclass `CAMODetector` in `deepfake_detectors/camo_detector.py`. +### 3. gates/ +This folder contains abstract base classes and implementations of `Gate`s that are used to handle content-aware preprocessing and routing. This is especially useful for multi-agent detection systems, such as the `DeepfakeDetector` subclass `CAMODetector`. - **Abstract Gate Class**: A base class for implementing image content gating. - **Gate Subclasses**: These subclasses define specific gating mechanisms responsible for routing inputs to appropriate expert detectors or preprocessing steps based on content characteristics. This is useful for multi-detector or mixture-of-expert detector setups. ### 4. registry.py -The `registry.py` file is responsible for managing the creation of detectors and gates using a **Factory Method** design pattern. It auto-registers all `DeepfakeDetector` and `Gate` subclasses from their subfolders to respective `Registry` constants, making it simple to instantiate detectors and gates dynamically based on predefined constants. +The `registry.py` file is responsible for managing the creation of detectors and gates using a **Factory Method** design pattern. It auto-registers all `DeepfakeDetector` and `Gate` subclasses from their respective folders to respective `Registry` constants, making it simple to instantiate detectors and gates dynamically based on predefined constants. - **Factory Pattern**: Ensures a clean, maintainable, and scalable method for creating instances of detectors and gating mechanisms. - **Auto-Registration**: Automatically registers all available detector and gate subclasses, enabling a flexible and extensible system. diff --git a/arena/detectors/__init__.py b/arena/detectors/__init__.py index 7748609..3653112 100644 --- a/arena/detectors/__init__.py +++ b/arena/detectors/__init__.py @@ -1,3 +1,8 @@ from .registry import DETECTOR_REGISTRY, GATE_REGISTRY -from .deepfake_detectors import NPRDetector, UCFDetector, CAMODetector -from .gating_mechanisms import FaceGate, GatingMechanism \ No newline at end of file +from .gates import FaceGate, GatingMechanism +from .deepfake_detector import DeepfakeDetector +from .npr_detector import NPRDetector +from .ucf_detector import UCFDetector +from .camo_detector import CAMODetector +from .spsl_detector import SPSLDetector + diff --git a/arena/detectors/deepfake_detectors/camo_detector.py b/arena/detectors/camo_detector.py similarity index 95% rename from arena/detectors/deepfake_detectors/camo_detector.py rename to arena/detectors/camo_detector.py index 435147a..f77c1c1 100644 --- a/arena/detectors/deepfake_detectors/camo_detector.py +++ b/arena/detectors/camo_detector.py @@ -2,9 +2,10 @@ import yaml import torch from PIL import Image + from arena.detectors.registry import DETECTOR_REGISTRY -from arena.detectors.gating_mechanisms import GatingMechanism -from arena.detectors.deepfake_detectors import DeepfakeDetector +from arena.detectors.gates import GatingMechanism +from arena.detectors.deepfake_detector import DeepfakeDetector @DETECTOR_REGISTRY.register_module(module_name='CAMO') diff --git a/arena/detectors/deepfake_detectors/configs/camo.yaml b/arena/detectors/configs/camo.yaml similarity index 100% rename from arena/detectors/deepfake_detectors/configs/camo.yaml rename to arena/detectors/configs/camo.yaml diff --git a/arena/detectors/deepfake_detectors/configs/npr.yaml b/arena/detectors/configs/npr.yaml similarity index 100% rename from arena/detectors/deepfake_detectors/configs/npr.yaml rename to arena/detectors/configs/npr.yaml diff --git a/arena/detectors/configs/spsl.yaml b/arena/detectors/configs/spsl.yaml new file mode 100644 index 0000000..b75416f --- /dev/null +++ b/arena/detectors/configs/spsl.yaml @@ -0,0 +1,3 @@ +hf_repo: 'bitmind/spsl' # Hugging Face repository for downloading model files +weights: 'spsl_best.pth' # model checkpoint in HuggingFaces +train_config: 'spsl.yaml' \ No newline at end of file diff --git a/arena/detectors/deepfake_detectors/configs/ucf.yaml b/arena/detectors/configs/ucf.yaml similarity index 100% rename from arena/detectors/deepfake_detectors/configs/ucf.yaml rename to arena/detectors/configs/ucf.yaml diff --git a/arena/detectors/deepfake_detectors/configs/ucf_face.yaml b/arena/detectors/configs/ucf_face.yaml similarity index 100% rename from arena/detectors/deepfake_detectors/configs/ucf_face.yaml rename to arena/detectors/configs/ucf_face.yaml diff --git a/arena/detectors/deepfake_detectors/deepfake_detector.py b/arena/detectors/deepfake_detector.py similarity index 100% rename from arena/detectors/deepfake_detectors/deepfake_detector.py rename to arena/detectors/deepfake_detector.py diff --git a/arena/detectors/deepfake_detectors/__init__.py b/arena/detectors/deepfake_detectors/__init__.py deleted file mode 100644 index 08e0636..0000000 --- a/arena/detectors/deepfake_detectors/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .deepfake_detector import DeepfakeDetector -from .npr_detector import NPRDetector -from .ucf_detector import UCFDetector -from .camo_detector import CAMODetector \ No newline at end of file diff --git a/arena/detectors/gating_mechanisms/__init__.py b/arena/detectors/gates/__init__.py similarity index 100% rename from arena/detectors/gating_mechanisms/__init__.py rename to arena/detectors/gates/__init__.py diff --git a/arena/detectors/gating_mechanisms/face_gate.py b/arena/detectors/gates/face_gate.py similarity index 91% rename from arena/detectors/gating_mechanisms/face_gate.py rename to arena/detectors/gates/face_gate.py index b869ff0..df7d368 100644 --- a/arena/detectors/gating_mechanisms/face_gate.py +++ b/arena/detectors/gates/face_gate.py @@ -3,10 +3,10 @@ import numpy as np import dlib -from arena.detectors.gating_mechanisms import Gate -from arena.detectors.UCF.config.constants import DLIB_FACE_PREDICTOR_PATH +from arena.detectors.gates import Gate from arena.detectors import GATE_REGISTRY -from arena.detectors.gating_mechanisms.face_utils import get_face_landmarks, align_and_crop_face +from arena.detectors.gates.face_utils import get_face_landmarks, align_and_crop_face +from arena.architectures.UCF.config.constants import DLIB_FACE_PREDICTOR_PATH @GATE_REGISTRY.register_module(module_name='FACE') diff --git a/arena/detectors/gating_mechanisms/face_utils.py b/arena/detectors/gates/face_utils.py similarity index 100% rename from arena/detectors/gating_mechanisms/face_utils.py rename to arena/detectors/gates/face_utils.py diff --git a/arena/detectors/gating_mechanisms/gate.py b/arena/detectors/gates/gate.py similarity index 100% rename from arena/detectors/gating_mechanisms/gate.py rename to arena/detectors/gates/gate.py diff --git a/arena/detectors/gating_mechanisms/gating_mechanism.py b/arena/detectors/gates/gating_mechanism.py similarity index 100% rename from arena/detectors/gating_mechanisms/gating_mechanism.py rename to arena/detectors/gates/gating_mechanism.py diff --git a/arena/detectors/deepfake_detectors/npr_detector.py b/arena/detectors/npr_detector.py similarity index 91% rename from arena/detectors/deepfake_detectors/npr_detector.py rename to arena/detectors/npr_detector.py index 76ef41a..2b7c7b2 100644 --- a/arena/detectors/deepfake_detectors/npr_detector.py +++ b/arena/detectors/npr_detector.py @@ -5,10 +5,10 @@ from huggingface_hub import hf_hub_download from bitmind.image_transforms import base_transforms -from arena.detectors.NPR.networks.resnet import resnet50 -from arena.detectors.deepfake_detectors import DeepfakeDetector -from arena.detectors import DETECTOR_REGISTRY -from arena.detectors.NPR.config.constants import WEIGHTS_DIR +from arena.detectors.registry import DETECTOR_REGISTRY +from arena.detectors.deepfake_detector import DeepfakeDetector +from arena.architectures.NPR.networks.resnet import resnet50 +from arena.architectures.NPR.config.constants import WEIGHTS_DIR @DETECTOR_REGISTRY.register_module(module_name='NPR') diff --git a/arena/detectors/spsl_detector.py b/arena/detectors/spsl_detector.py new file mode 100644 index 0000000..c48a0dc --- /dev/null +++ b/arena/detectors/spsl_detector.py @@ -0,0 +1,139 @@ +import os +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Ignore INFO and WARN messages + +import random +import warnings +warnings.filterwarnings("ignore", category=FutureWarning) +from pathlib import Path + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torchvision.transforms as transforms +import yaml +from PIL import Image +from huggingface_hub import hf_hub_download +import gc + +from arena.detectors.registry import DETECTOR_REGISTRY +from arena.detectors.deepfake_detector import DeepfakeDetector +from arena.architectures.SPSL.config.constants import CONFIGS_DIR, WEIGHTS_DIR +from arena.architectures.SPSL.detectors import DETECTOR + +@DETECTOR_REGISTRY.register_module(module_name='SPSL') +class SPSLDetector(DeepfakeDetector): + """ + DeepfakeDetector subclass that initializes a pretrained SPSL model + for binary classification of fake and real images. + + Attributes: + model_name (str): Name of the detector instance. + config (str): Name of the YAML file in deepfake_detectors/config/ to load + attributes from. + device (str): The type of device ('cpu' or 'cuda'). + """ + + def __init__(self, model_name: str = 'SPSL', config: str = 'spsl.yaml', device: str = 'cpu'): + super().__init__(model_name, config, device) + + def ensure_weights_are_available(self, weight_filename): + destination_path = Path(WEIGHTS_DIR) / Path(weight_filename) + if not destination_path.parent.exists(): + destination_path.parent.mkdir(parents=True, exist_ok=True) + if not destination_path.exists(): + model_path = hf_hub_download(self.hf_repo, weight_filename) + model = torch.load(model_path, map_location=self.device) + torch.save(model, destination_path) + + def load_train_config(self): + destination_path = Path(CONFIGS_DIR) / Path(self.train_config) + + if not destination_path.exists(): + local_config_path = hf_hub_download(self.hf_repo, self.train_config) + print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}") + config_dict = {} + with open(local_config_path, 'r') as f: + config_dict = yaml.safe_load(f) + with open(destination_path, 'w') as f: + yaml.dump(config_dict, f, default_flow_style=False) + with destination_path.open('r') as f: + return yaml.safe_load(f) + else: + print(f"Loaded local config from {destination_path}") + with destination_path.open('r') as f: + return yaml.safe_load(f) + + def init_cudnn(self): + if self.train_config.get('cudnn'): + cudnn.benchmark = True + + def init_seed(self): + seed_value = self.train_config.get('manualSeed') + if seed_value: + random.seed(seed_value) + torch.manual_seed(seed_value) + torch.cuda.manual_seed_all(seed_value) + + def load_model(self): + self.train_config = self.load_train_config() + self.init_cudnn() + self.init_seed() + self.ensure_weights_are_available(self.weights) + pretrained_weights = self.train_config['pretrained'].split('/')[-1] + self.ensure_weights_are_available(pretrained_weights) + self.train_config['pretrained'] = str(Path(WEIGHTS_DIR) / pretrained_weights) + + model_class = DETECTOR[self.train_config['model_name']] + self.model = model_class(self.train_config).to(self.device) + self.model.eval() + weights_path = Path(WEIGHTS_DIR) / self.weights + checkpoint = torch.load(weights_path, map_location=self.device) + try: + self.model.load_state_dict(checkpoint, strict=True) + except RuntimeError as e: + raise e + + def preprocess(self, image, res=256): + """Preprocess the image for model inference. + + Returns: + torch.Tensor: The preprocessed image tensor, ready for model inference. + """ + # Convert image to RGB format to ensure consistent color handling. + image = image.convert('RGB') + + # Define transformation sequence for image preprocessing. + transform = transforms.Compose([ + transforms.Resize((res, res), interpolation=Image.LANCZOS), # Resize image to specified resolution. + transforms.ToTensor(), # Convert the image to a PyTorch tensor. + transforms.Normalize(mean=self.train_config['mean'], std=self.train_config['std']) # Normalize the image tensor. + ]) + + # Apply transformations and add a batch dimension for model inference. + image_tensor = transform(image).unsqueeze(0) + + # Move the image tensor to the specified device (e.g., GPU). + return image_tensor.to(self.device) + + def infer(self, image_tensor): + """ Perform inference using the model. """ + with torch.no_grad(): + pred_dict = self.model({'image': image_tensor}) + return pred_dict['prob'] + + def __call__(self, image: Image) -> float: + image_tensor = self.preprocess(image) + return self.infer(image_tensor) + + def free_memory(self): + """ Frees up memory by setting model and large data structures to None. """ + if self.model is not None: + self.model.cpu() # Move model to CPU to free up GPU memory (if applicable) + del self.model + self.model = None + + gc.collect() + + # If using GPUs and PyTorch, clear the cache as well + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/arena/detectors/tutorial.md b/arena/detectors/tutorial.md new file mode 100644 index 0000000..06ac370 --- /dev/null +++ b/arena/detectors/tutorial.md @@ -0,0 +1,64 @@ +# Tutorial: Adding a New Deepfake Detector + +This tutorial will guide you through the process of adding a new deepfake detector to the DFD Arena framework. We'll use the SPSL (Spatial-Phase Shallow Learning) architecture, sourced from DeepfakeBench, as an example. + +## 1. Get Your Model Weights + +First, you need to have your model weights ready. This can be either your own architecture or a pretrained one from prior literature. For this example, we're using weights sourced from DeepfakeBench and hosted in a BitMind Hugging Face + +## 2. Add Your Model-Specific Files to the Architectures Directory + +Create a directory for your architecture within `arena/architectures/`. For our SPSL example, we've created [`arena/architectures/SPSL/`](../architectures/SPSL/). + +## 3. Define Your Detector Class + +Next, create a detector script within the `detectors` directory. For our SPSL example, we'll name it `spsl_detector.py`. In this new script, import all necessary dependencies and define your detector class. You may utilize the config parameter for model loading (see UCF for example usage). + +For those interested in using the detector registry system, which is particularly useful for multi-model detectors like CAMO, you can refer to [`registry.py`](registry.py) and [`camo_detector.py`](camo_detector.py) as examples. + +Here is an example of [`spsl_detector.py`](spsl_detector.py): +```python +from arena.detectors.registry import DETECTOR_REGISTRY +from arena.detectors.deepfake_detector import DeepfakeDetector +from arena.architectures.SPSL.config.constants import CONFIGS_DIR, WEIGHTS_DIR +from arena.architectures.SPSL.detectors import DETECTOR + +@DETECTOR_REGISTRY.register_module(module_name='SPSL') +class SPSLDetector(DeepfakeDetector): + def __init__(self, model_name: str = 'SPSL', config: str = 'spsl.yaml', device: str = 'cpu'): + super().__init__(model_name, config, device) + + ... + + def infer(self, image_tensor): + """ Perform inference using the model. """ + with torch.no_grad(): + pred_dict = self.model({'image': image_tensor}) + return pred_dict['prob'] + + def __call__(self, image: Image) -> float: + image_tensor = self.preprocess(image) + return self.infer(image_tensor) +``` + +## 4. Import and Register Your Detector + +Import your detector in the [`arena/detectors/__init__.py`](__init__.py) file: + +```python +from .spsl_detector import SPSLDetector +``` + +## 5. Evaluate Your Detector + +To evaluate your detector, you can use the existing evaluation scripts in the DFD Arena framework. Make sure your architecture is properly integrated and can be selected for evaluation using the configuration file you created. + +## 6. Set Up Gates (if necessary) + +If your architecture requires any specific gates or preprocessing steps, you can set them up in a similar manner to how you registered the architecture. Use the `GATE_REGISTRY` for this purpose. + +## Understanding the Registration Process + +The `DETECTOR_REGISTRY` is imported within `neurons/miner.py`. When you run `setup_miner_env.sh`, it generates a `miner.env` file that specifies which detector configuration to use. The miner script then loads this configuration, instantiates the corresponding architecture, and utilizes it within its forward function. + +By following these steps, you've successfully added a new deepfake detector architecture to the DFD Arena framework. You can now use and evaluate your detector alongside other implemented models. \ No newline at end of file diff --git a/arena/detectors/deepfake_detectors/ucf_detector.py b/arena/detectors/ucf_detector.py similarity index 95% rename from arena/detectors/deepfake_detectors/ucf_detector.py rename to arena/detectors/ucf_detector.py index a7a681c..b88361c 100644 --- a/arena/detectors/deepfake_detectors/ucf_detector.py +++ b/arena/detectors/ucf_detector.py @@ -16,12 +16,12 @@ import gc from bitmind.image_transforms import ucf_transforms, ConvertToRGB, CenterCrop, CLAHE -from arena.detectors.UCF.config.constants import CONFIGS_DIR, WEIGHTS_DIR -from arena.detectors.gating_mechanisms import FaceGate -from arena.detectors.UCF.detectors import DETECTOR -from arena.detectors.deepfake_detectors import DeepfakeDetector -from arena.detectors import DETECTOR_REGISTRY, GATE_REGISTRY +from arena.detectors.registry import DETECTOR_REGISTRY, GATE_REGISTRY +from arena.detectors.deepfake_detector import DeepfakeDetector +from arena.detectors.gates import FaceGate +from arena.architectures.UCF.config.constants import CONFIGS_DIR, WEIGHTS_DIR +from arena.architectures.UCF.detectors import DETECTOR from arena.utils.image_transforms import CLAHE diff --git a/arena/detectors/unit_tests/__init__.py b/arena/detectors/unit_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/arena/detectors/deepfake_detectors/unit_tests/base_npr_weights.pth b/arena/detectors/unit_tests/base_npr_weights.pth similarity index 100% rename from arena/detectors/deepfake_detectors/unit_tests/base_npr_weights.pth rename to arena/detectors/unit_tests/base_npr_weights.pth diff --git a/arena/detectors/deepfake_detectors/unit_tests/sample_image.jpg b/arena/detectors/unit_tests/sample_image.jpg similarity index 100% rename from arena/detectors/deepfake_detectors/unit_tests/sample_image.jpg rename to arena/detectors/unit_tests/sample_image.jpg diff --git a/arena/detectors/deepfake_detectors/unit_tests/test_camo_detector.py b/arena/detectors/unit_tests/test_camo_detector.py similarity index 100% rename from arena/detectors/deepfake_detectors/unit_tests/test_camo_detector.py rename to arena/detectors/unit_tests/test_camo_detector.py diff --git a/arena/detectors/deepfake_detectors/unit_tests/test_npr_detector.py b/arena/detectors/unit_tests/test_npr_detector.py similarity index 100% rename from arena/detectors/deepfake_detectors/unit_tests/test_npr_detector.py rename to arena/detectors/unit_tests/test_npr_detector.py diff --git a/arena/detectors/deepfake_detectors/unit_tests/test_registry.py b/arena/detectors/unit_tests/test_registry.py similarity index 100% rename from arena/detectors/deepfake_detectors/unit_tests/test_registry.py rename to arena/detectors/unit_tests/test_registry.py diff --git a/arena/detectors/deepfake_detectors/unit_tests/test_ucf_detector.py b/arena/detectors/unit_tests/test_ucf_detector.py similarity index 100% rename from arena/detectors/deepfake_detectors/unit_tests/test_ucf_detector.py rename to arena/detectors/unit_tests/test_ucf_detector.py diff --git a/arena/utils/data.py b/arena/utils/data.py index 3a2db02..cb8427b 100644 --- a/arena/utils/data.py +++ b/arena/utils/data.py @@ -13,7 +13,6 @@ def load_datasets(datasets): ds['path'], huggingface_dataset_split=ds['split'], huggingface_dataset_name=ds.get('name', None), - create_splits=False, download_mode='reuse_cache_if_exists') for ds in datasets['fake'] ] @@ -23,7 +22,6 @@ def load_datasets(datasets): ds['path'], huggingface_dataset_split=ds['split'], huggingface_dataset_name=ds.get('name', None), - create_splits=False, download_mode='reuse_cache_if_exists') for ds in datasets['real'] ]