Repository: Xiaohan-Chen/transfer-learning-fault-diagnosis-pytorch
Branch: main
Commit: 44afdc60956f
Files: 18
Total size: 72.8 KB

Directory structure:
transfer-learning-fault-diagnosis-pytorch/

├── Backbone/
│   ├── CNN1D.py
│   ├── MLPNet.py
│   └── ResNet1D.py
├── DANN.py
├── DDC.py
├── OSDABP.py
├── PreparData/
│   ├── CWRU.py
│   ├── __init__.py
│   └── preprocess.py
├── README.md
├── Utils/
│   ├── __init__.py
│   ├── logger.py
│   └── utils.py
├── classification.py
└── loss/
    ├── CORAL.py
    ├── MKMMD.py
    ├── MMDLinear.py
    └── __init__.py

================================================
FILE CONTENTS
================================================

================================================
FILE: Backbone/CNN1D.py
================================================
import torch.nn as nn

class CNN1D(nn.Module):
    def __init__(self, num_out = 10):
        super(CNN1D, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv1d(1,32,kernel_size=3,padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, padding=0)
            )
        self.layer2 = nn.Sequential(
            nn.Conv1d(32,64,kernel_size=3,padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, padding=0)
            )
        self.layer3 = nn.Sequential(
            nn.Conv1d(64,64,kernel_size=3,padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, padding=0)
            )
        self.avgpool = nn.AdaptiveAvgPool1d(1) # output (64,1)
        # NOTE: fc is not applied in forward(); the transfer scripts attach their
        # own classifier heads to the 64-dimensional pooled features.
        self.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(64, num_out))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(-1,64)

        return x
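
if __name__ == "__main__":
    # quick smoke test (a sketch, not part of the original module): the network
    # maps (batch, 1, length) inputs to (batch, 64) feature vectors
    import torch
    net = CNN1D()
    print(net(torch.randn(2, 1, 1024)).shape)  # torch.Size([2, 64])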

================================================
FILE: Backbone/MLPNet.py
================================================
import torch.nn as nn

class MLPNet(nn.Module):
    def __init__(self, num_in = 1024, num_out = 10):
        super(MLPNet, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(num_in,512),
            nn.BatchNorm1d(512),
            nn.ReLU()
            )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU()
            )
        self.fc3 = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
            )
        self.fc4 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU()
            )
        self.fc5 = nn.Linear(64,num_out)
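        # NOTE: fc5 is defined but not applied in forward(); the transfer scripts
        # attach their own classifier heads to the 64-dimensional fc4 features.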

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)

        return x

================================================
FILE: Backbone/ResNet1D.py
================================================
# one-dimensional ResNet source code reference: https://github.com/ZhaoZhibin/UDTL/blob/master/models/resnet18_1d.py

import torch.nn as nn
import torch.utils.model_zoo as model_zoo

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

def conv3x1(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x1(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x1(planes, planes)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # For ResNet50, ResNet101, ResNet152
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = conv3x1(planes, planes, stride)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm1d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, in_channel=1, out_channel=10, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv1d(in_channel, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool1d(1)  # output (512, 1)


        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm1d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


# convnet without the last layer
class resnet18_features(nn.Module):
    def __init__(self, pretrained=False):
        super(resnet18_features, self).__init__()
        self.model_resnet18 = resnet18(pretrained)
        self.__in_features = 512

    def forward(self, x):
        x = self.model_resnet18(x)
        return x

    def output_num(self):
        return self.__in_features

================================================
FILE: DANN.py
================================================
import argparse
import os
import numpy as np
from Utils.logger import setlogger

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from Backbone import ResNet1D, MLPNet, CNN1D
from loss import MKMMD, MMDLinear, CORAL
from PreparData.CWRU import CWRUloader
import Utils.utils as utils

from tqdm import tqdm
import warnings
import logging

# ===== Define argments =====
def parse_args():
    parser = argparse.ArgumentParser(description='Implementation of Domain Adversarial Neural Networks')

    # task setting
    parser.add_argument("--log_file", type=str, default="./logs/DANN.log", help="log file path")

    # dataset information
    parser.add_argument("--datadir", type=str, default="./datasets", help="data directory")
    parser.add_argument("--source_dataname", type=str, default="CWRU", choices=["CWRU", "PU"], help="choice a dataset")
    parser.add_argument("--target_dataname", type=str, default="CWRU", choices=["CWRU", "PU"], help="choice a dataset")
    parser.add_argument("--s_load", type=int, default=3, help="source domain working condition")
    parser.add_argument("--t_load", type=int, default=2, help="target domain working condition")
    parser.add_argument("--s_label_set", type=list, default=[0,1,2,3,4,5,6,7,8,9], help="source domain label set")
    parser.add_argument("--t_label_set", type=list, default=[0,1,2,3,4,5,6,7,8,9], help="target domain label set")
    parser.add_argument("--val_rat", type=float, default=0.3, help="training-validation rate")
    parser.add_argument("--test_rat", type=float, default=0.5, help="validation-test rate")
    parser.add_argument("--seed", type=int, default="29")

    # pre-processing
    parser.add_argument("--fft", type=bool, default=False, help="FFT preprocessing")
    parser.add_argument("--window", type=int, default=128, help="time window, if not augment data, window=1024")
    parser.add_argument("--normalization", type=str, default="0-1", choices=["None", "0-1", "mean-std"], help="normalization option")
    parser.add_argument("--savemodel", type=bool, default=False, help="whether save pre-trained model in the classification task")
    parser.add_argument("--pretrained", type=bool, default=False, help="whether use pre-trained model in transfer learning tasks")

    # backbone
    parser.add_argument("--backbone", type=str, default="ResNet1D", choices=["ResNet1D", "ResNet2D", "MLPNet", "CNN1D"])
    # if   backbone in ("ResNet1D", "CNN1D"),  data shape: (batch size, 1, 1024)
    # elif backbone == "ResNet2D",             data shape: (batch size, 3, 32, 32)
    # elif backbone == "MLPNet",               data shape: (batch size, 1024)


    # optimization & training
    parser.add_argument("--num_workers", type=int, default=0, help="the number of dataloader workers")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--max_epoch", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument('--lr_scheduler', type=str, default='stepLR', choices=['step', 'exp', 'stepLR', 'fix'], help='the learning rate schedule')
    parser.add_argument('--gamma', type=float, default=0.8, help='learning rate scheduler parameter for step and exp')
    parser.add_argument('--steps', type=str, default='30, 120', help='the learning rate decay for step and stepLR')
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "sgd"])


    args = parser.parse_args()
    return args

# ===== Build Model =====
class FeatureNet(nn.Module):
    def __init__(self, args):
        super(FeatureNet, self).__init__()
        if args.backbone == "ResNet1D":
            self.feature_net = ResNet1D.resnet18()
        elif args.backbone == "ResNet2D":
            self.model_ft = models.resnet18(pretrained=True)
            self.bottleneck = nn.Sequential(nn.Linear(self.model_ft.fc.out_features, 512), nn.ReLU(), nn.Dropout(0.5))
            self.feature_net = nn.Sequential(self.model_ft, self.bottleneck)
        elif args.backbone == "MLPNet":
            if args.fft:
                self.feature_net = MLPNet.MLPNet(num_in=512)
            else:
                self.feature_net = MLPNet.MLPNet()
        elif args.backbone == "CNN1D":
            self.feature_net = CNN1D.CNN1D()
        else:
            raise Exception("model not implement")

    def forward(self, x):
        logits = self.feature_net(x)

        return logits

class Classifier(nn.Module):
    def __init__(self, args, num_out=10):
        super(Classifier, self).__init__()
        if args.backbone in ("ResNet1D", "ResNet2D"):
            self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, num_out))
        if args.backbone in ("MLPNet", "CNN1D"):
            self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(64, num_out))

    def forward(self, logits):
        outputs = self.classifier(logits)

        return outputs

# Define the discriminator
def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    # GRL coefficient schedule from the DANN paper:
    # 2(high-low)/(1+exp(-alpha*p)) - (high-low) + low, with p = iter_num/max_iter
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)

## The hook will be called every time a gradient with respect to the Tensor is computed.
## https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html?highlight=register_hook#torch.Tensor.register_hook
def grl_hook(coeff):
    def fun1(grad):
        return -coeff * grad.clone()
    return fun1
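
## For intuition (a sketch, not used by the training flow): with
##     x = torch.ones(3, requires_grad=True); y = x * 1.0
##     y.register_hook(grl_hook(1.0)); y.sum().backward()
## x.grad becomes tensor([-1., -1., -1.]): the hook scales the upstream gradient by
## -coeff, which implements a gradient reversal layer without a custom autograd Function.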

class Discriminator(nn.Module):
    def __init__(self, args, num_out = 1, max_iter=10000.0, trade_off_adversarial='Cons', lam_adversarial=1.0):
        super(Discriminator, self).__init__()
        if args.backbone in ("ResNet1D", "ResNet2D"):
            self.domain_classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Linear(128, num_out)
                )
        elif args.backbone in ("MLPNet", "CNN1D"):
            self.domain_classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(64, 32),
                nn.BatchNorm1d(32),
                nn.ReLU(),
                nn.Linear(32, num_out)
                )
        self.sigmoid = nn.Sigmoid()

        # parameters
        self.iter_num = 0
        self.alpha = 10
        self.low = 0.0
        self.high = 1.0
        self.max_iter = max_iter
        self.trade_off_adversarial = trade_off_adversarial
        self.lam_adversarial = lam_adversarial
    
    def forward(self, x):
        if self.training:
            self.iter_num += 1
        if self.trade_off_adversarial == "Cons":
            coeff = self.lam_adversarial
        elif self.trade_off_adversarial == "Step":
            coeff = calc_coeff(self.iter_num, self.high, self.low,\
                self.alpha, self.max_iter)
        else:
            raise Exception("loss not implement")
        x = x * 1.0  # make a non-leaf copy so the reversal hook applies only to this branch
        x.register_hook(grl_hook(coeff))
        x = self.domain_classifier(x)
        x = self.sigmoid(x)
        return x

# ===== Load Data =====
def loaddata(args):
    if args.source_dataname == "CWRU":
        source_data, source_label = CWRUloader(args, args.s_load, args.s_label_set)

    source_data, source_label = np.concatenate(source_data, axis=0), np.concatenate(source_label, axis=0)
    
    if args.target_dataname == "CWRU":
        target_data, target_label = CWRUloader(args, args.t_load, args.t_label_set)

    target_data, target_label = np.concatenate(target_data, axis=0), np.concatenate(target_label, axis=0)

    source_loader, _, _ = utils.DataSplite(args, source_data, source_label)
    target_trainloader, target_valloader, target_testloader = utils.DataSplite(args, target_data, target_label)
    
    return source_loader, target_trainloader, target_valloader, target_testloader

# ===== Test the Model =====
def tester(featurenet, classifier, dataloader):
    featurenet.eval()
    classifier.eval()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    correct_num, total_num = 0, 0
    for i, (x_batch, y_batch) in enumerate(dataloader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # compute model output
        logits_batch = featurenet(x_batch)
        output_batch = classifier(logits_batch)

        pre = torch.max(output_batch.cpu(), 1)[1].numpy()
        y = y_batch.cpu().numpy()
        correct_num += (pre == y).sum()
        total_num += len(y)
    accuracy = (correct_num / total_num) * 100.0
    return accuracy

# ===== Train the Model =====
def trainer(args):
    # Consider the gpu or cpu condition
    if torch.cuda.is_available():
        device = torch.device("cuda")
        device_count = torch.cuda.device_count()
        logging.info('using {} gpus'.format(device_count))
        assert args.batch_size % device_count == 0, "batch size should be divisible by device count"
    else:
        warnings.warn("gpu is not available")
        device = torch.device("cpu")
        device_count = 1
        logging.info('using {} cpu'.format(device_count))
    
    # load the dataset
    source_trainloader, target_trainloader, target_valloader, target_testloader = loaddata(args)

    # load the model
    featurenet = FeatureNet(args)
    classifier = Classifier(args, num_out=len(args.t_label_set))
    discriminator = Discriminator(args)

    # load the checkpoint
    if args.pretrained:
        if args.backbone != "ResNet2D": # pretrained ResNet2D model is downloaded from torchvision module
            if not args.fft:
                path = "./checkpoints/{}_checkpoint.tar".format(args.backbone)
            else:
                path = "./checkpoints/{}FFT_checkpoint.tar".format(args.backbone)
            featurenet.load_state_dict(torch.load(path))

    parameter_list = [{"params": featurenet.parameters(), "lr": 0.5*args.lr},
                        {"params": classifier.parameters(), "lr": args.lr},
                       {"params": discriminator.parameters(), "lr": args.lr}]

    # Define optimizer and learning rate decay
    optimizer, lr_scheduler = utils.optimizer(args, parameter_list)

    ## define loss function
    loss_cls = nn.CrossEntropyLoss()
    loss_adver = nn.BCELoss()

    featurenet.to(device)
    classifier.to(device)
    discriminator.to(device)

    # train
    best_acc = 0.0
    meters = {"acc_source_train":[], "acc_target_train": [], "acc_target_val": []}

    for epoch in range(args.max_epoch):
        featurenet.train()
        classifier.train()
        with tqdm(total=len(target_trainloader), leave=False) as pbar:
            for i, ((x_s_batch, y_s_batch), (x_t_batch, y_t_batch)) in enumerate(zip(source_trainloader,target_trainloader)):

                if len(y_s_batch) != len(y_t_batch):
                    break
                
                batch_num = x_s_batch.size(0)
                
                domain_label_source = torch.ones(batch_num).float()
                domain_label_target = torch.zeros(batch_num).float()

                inputs = torch.cat((x_s_batch, x_t_batch), dim=0)
                domain_label = torch.cat((domain_label_source, domain_label_target), dim=0)

                # move to GPU if available
                inputs = inputs.to(device)
                s_labels = y_s_batch.to(device)
                t_labels = y_t_batch.to(device)
                domain_label = domain_label.to(device)

                # compute model output and loss
                logits = featurenet(inputs)
                outputs = classifier(logits)
                domain_outputs = discriminator(logits)


                classification_loss = loss_cls(outputs.narrow(0, 0, batch_num), s_labels.long())
                adversarial_loss = loss_adver(domain_outputs.squeeze(), domain_label)
                loss = classification_loss + adversarial_loss
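
                # adversarial_loss backpropagates through grl_hook, which flips the sign
                # of the gradient at the shared features: the discriminator learns to
                # separate domains while the feature net learns to confuse it (DANN's min-max game)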


                # clear previous gradients, compute gradients
                optimizer.zero_grad()
                loss.backward()

                # performs updates using calculated gradients
                optimizer.step()

                # evaluate
                # training accuracy
                acc_source_train = utils.accuracy(outputs.narrow(0, 0, batch_num), s_labels)
                acc_target_train = utils.accuracy(outputs.narrow(0, batch_num, batch_num), t_labels)              

                pbar.update()
        
        # update lr
        if lr_scheduler is not None:
            lr_scheduler.step()

        val_acc = tester(featurenet, classifier, target_valloader)
        if val_acc > best_acc:
            best_acc = val_acc
            if args.savemodel:
                utils.save_model(featurenet, args)
        
        logging.info("Epoch: {:>3}/{}, loss_cls: {:.4f}, loss: {:.4f}, source_train_acc: {:>6.2f}%, target_train_acc: {:>6.2f}%, target_val_acc: {:>6.2f}%".format(\
                epoch+1, args.max_epoch, classification_loss, loss, acc_source_train, acc_target_train, val_acc))
        meters["acc_source_train"].append(acc_source_train)
        meters["acc_target_train"].append(acc_target_train)
        meters["acc_target_val"].append(val_acc)

    logging.info("Best accuracy: {:.4f}".format(best_acc))
    utils.save_log(meters, "./logs/DDC_{}_{}_meters.pkl".format(args.backbone, args.max_epoch))

    logging.info("="*15+"Done!"+"="*15)

if __name__ == "__main__":

    args = parse_args()

    # set the logger
    if not os.path.exists("./logs"):
        os.makedirs("./logs")
    setlogger(args.log_file)

    # save the args
    for k, v in args.__dict__.items():
        logging.info("{}: {}".format(k, v))
    
    trainer(args)

================================================
FILE: DDC.py
================================================
import argparse
import os
import numpy as np
from Utils.logger import setlogger

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from Backbone import ResNet1D, MLPNet, CNN1D
from loss import MKMMD, MMDLinear, CORAL
from PreparData.CWRU import CWRUloader
import Utils.utils as utils

from tqdm import tqdm
import warnings
import logging

# ===== Define argments =====
def parse_args():
    parser = argparse.ArgumentParser(description='Implementation of Deep Domain Confusion networks')

    # task setting
    parser.add_argument("--log_file", type=str, default="./logs/DDC.log", help="log file path")

    # dataset information
    parser.add_argument("--datadir", type=str, default="./datasets", help="data directory")
    parser.add_argument("--source_dataname", type=str, default="CWRU", choices=["CWRU"], help="choice a dataset")
    parser.add_argument("--target_dataname", type=str, default="CWRU", choices=["CWRU"], help="choice a dataset")
    parser.add_argument("--s_load", type=int, default=3, help="source domain working condition")
    parser.add_argument("--t_load", type=int, default=2, help="target domain working condition")
    parser.add_argument("--s_label_set", type=list, default=[0,1,2,3,4,5,6,7,8,9], help="source domain label set")
    parser.add_argument("--t_label_set", type=list, default=[0,1,2,3,4,5,6,7,8,9], help="target domain label set")
    parser.add_argument("--val_rat", type=float, default=0.3, help="training-validation rate")
    parser.add_argument("--test_rat", type=float, default=0.5, help="validation-test rate")
    parser.add_argument("--seed", type=int, default="29")

    # pre-processing
    parser.add_argument("--fft", type=bool, default=False, help="FFT preprocessing")
    parser.add_argument("--window", type=int, default=128, help="time window, if not augment data, window=1024")
    parser.add_argument("--normalization", type=str, default="0-1", choices=["None", "0-1", "mean-std"], help="normalization option")
    parser.add_argument("--savemodel", type=bool, default=False, help="whether save pre-trained model in the classification task")
    parser.add_argument("--pretrained", type=bool, default=False, help="whether use pre-trained model in transfer learning tasks")

    # backbone
    parser.add_argument("--backbone", type=str, default="ResNet1D", choices=["ResNet1D", "ResNet2D", "MLPNet", "CNN1D"])
    # if   backbone in ("ResNet1D", "CNN1D"),  data shape: (batch size, 1, 1024)
    # elif backbone == "ResNet2D",             data shape: (batch size, 3, 32, 32)
    # elif backbone == "MLPNet",               data shape: (batch size, 1024)


    # optimization & training
    parser.add_argument("--num_workers", type=int, default=0, help="the number of dataloader workers")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--max_epoch", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument('--lr_scheduler', type=str, default='stepLR', choices=['step', 'exp', 'stepLR', 'fix'], help='the learning rate schedule')
    parser.add_argument('--gamma', type=float, default=0.8, help='learning rate scheduler parameter for step and exp')
    parser.add_argument('--steps', type=str, default='30, 120', help='the learning rate decay for step and stepLR')
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "sgd"])
    parser.add_argument("--kernel", type=str, default='Linear', choices=["Linear", "CORAL"])


    args = parser.parse_args()
    return args

# ===== Build Model =====
class FeatureNet(nn.Module):
    def __init__(self, args):
        super(FeatureNet, self).__init__()
        if args.backbone == "ResNet1D":
            self.feature_net = ResNet1D.resnet18()
        elif args.backbone == "ResNet2D":
            self.model_ft = models.resnet18(pretrained=True)
            self.bottleneck = nn.Sequential(nn.Linear(self.model_ft.fc.out_features, 512), nn.ReLU(), nn.Dropout(0.5))
            self.feature_net = nn.Sequential(self.model_ft, self.bottleneck)
        elif args.backbone == "MLPNet":
            if args.fft:
                self.feature_net = MLPNet.MLPNet(num_in=512)
            else:
                self.feature_net = MLPNet.MLPNet()
        elif args.backbone == "CNN1D":
            self.feature_net = CNN1D.CNN1D()
        else:
            raise Exception("model not implement")

    def forward(self, x):
        logits = self.feature_net(x)

        return logits

class Classifier(nn.Module):
    def __init__(self, args, num_out=10):
        super(Classifier, self).__init__()
        if args.backbone in ("ResNet1D", "ResNet2D"):
            self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, num_out))
        if args.backbone in ("MLPNet", "CNN1D"):
            self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(64, num_out))

    def forward(self, logits):
        outputs = self.classifier(logits)

        return outputs

# ===== Load Data =====
def loaddata(args):
    if args.source_dataname == "CWRU":
        source_data, source_label = CWRUloader(args, args.s_load, args.s_label_set)
    else:
        raise NotImplementedError("Source dataset {} not implemented.".format(args.source_dataname))

    source_data, source_label = np.concatenate(source_data, axis=0), np.concatenate(source_label, axis=0)
    
    if args.target_dataname == "CWRU":
        target_data, target_label = CWRUloader(args, args.t_load, args.t_label_set)
    else:
        raise NotImplementedError("Target dataset {} not implemented.".format(args.target_dataname))

    target_data, target_label = np.concatenate(target_data, axis=0), np.concatenate(target_label, axis=0)

    source_loader, _, _ = utils.DataSplite(args, source_data, source_label)
    target_trainloader, target_valloader, target_testloader = utils.DataSplite(args, target_data, target_label)
    
    return source_loader, target_trainloader, target_valloader, target_testloader

# ===== Test the Model =====
def tester(featurenet, classifier, dataloader):
    featurenet.eval()
    classifier.eval()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    correct_num, total_num = 0, 0
    for i, (x_batch, y_batch) in enumerate(dataloader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # compute model output
        logits_batch = featurenet(x_batch)
        output_batch = classifier(logits_batch)

        pre = torch.max(output_batch.cpu(), 1)[1].numpy()
        y = y_batch.cpu().numpy()
        correct_num += (pre == y).sum()
        total_num += len(y)
    accuracy = (correct_num / total_num) * 100.0
    return accuracy


# ===== Train the Model =====
def trainer(args):
    # Consider the gpu or cpu condition
    if torch.cuda.is_available():
        device = torch.device("cuda")
        device_count = torch.cuda.device_count()
        logging.info('using {} gpus'.format(device_count))
        assert args.batch_size % device_count == 0, "batch size should be divisible by device count"
    else:
        warnings.warn("gpu is not available")
        device = torch.device("cpu")
        device_count = 1
        logging.info('using {} cpu'.format(device_count))
    
    # load the dataset
    source_trainloader, target_trainloader, target_valloader, target_testloader = loaddata(args)

    # load the model
    featurenet = FeatureNet(args)
    classifier = Classifier(args, num_out=len(args.t_label_set))

    # load the checkpoint
    if args.pretrained:
        if args.backbone != "ResNet2D": # pretrained ResNet2D model is downloaded from torchvision module
            if not args.fft:
                path = "./checkpoints/{}_checkpoint.tar".format(args.backbone)
            else:
                path = "./checkpoints/{}FFT_checkpoint.tar".format(args.backbone)
            featurenet.load_state_dict(torch.load(path))

    parameter_list = [{"params": featurenet.parameters(), "lr": args.lr},
                       {"params": classifier.parameters(), "lr": args.lr}]

    # Define optimizer and learning rate decay
    optimizer, lr_scheduler = utils.optimizer(args, parameter_list)

    # define loss function
    loss_cls = nn.CrossEntropyLoss()
    if args.kernel == "Linear":
        loss_dis = MMDLinear.MMDLinear
    elif args.kernel == "CORAL":
        loss_dis = CORAL.CORAL_loss
    else:
        raise NotImplementedError("Kernel {} not implemented.".format(args.kernel))
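    # both losses (defined under ./loss) take (source_batch, target_batch) feature
    # matrices and return a scalar domain-discrepancy penalty (see the call below)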

    featurenet.to(device)
    classifier.to(device)

    # train
    best_acc = 0.0
    meters = {"acc_source_train":[], "acc_target_train": [], "acc_target_val": []}

    for epoch in range(args.max_epoch):
        featurenet.train()
        classifier.train()
        with tqdm(total=len(target_trainloader), leave=False) as pbar:
            for i, ((x_s_batch, y_s_batch), (x_t_batch, y_t_batch)) in enumerate(zip(source_trainloader,target_trainloader)):

                if len(y_s_batch) != len(y_t_batch):
                    break
                batch_num = x_s_batch.size(0)

                inputs = torch.cat((x_s_batch, x_t_batch), dim=0)

                # move to GPU if available
                inputs = inputs.to(device)
                s_labels = y_s_batch.to(device)
                t_labels = y_t_batch.to(device)

                # compute model output and loss
                logits = featurenet(inputs)
                outputs = classifier(logits)

                classification_loss = loss_cls(outputs.narrow(0, 0, s_labels.size(0)), s_labels.long())
                distance_loss = loss_dis(outputs.view(outputs.size(0),-1).narrow(0, 0, s_labels.size(0)),\
                                            outputs.view(outputs.size(0),-1).narrow(0, s_labels.size(0), s_labels.size(0)))
                loss = classification_loss + distance_loss 

                # clear previous gradients, compute gradients
                optimizer.zero_grad()
                loss.backward()

                # performs updates using calculated gradients
                optimizer.step()

                # evaluate
                # training accuracy
                acc_source_train = utils.accuracy(outputs.narrow(0, 0, batch_num), s_labels)
                acc_target_train = utils.accuracy(outputs.narrow(0, batch_num, batch_num), t_labels)              

                pbar.update()
        
        # update lr
        if lr_scheduler is not None:
            lr_scheduler.step()

        val_acc = tester(featurenet, classifier, target_valloader)
        if val_acc > best_acc:
            best_acc = val_acc
            if args.savemodel:
                utils.save_model(featurenet, args)
        
        logging.info("Epoch: {:>3}/{}, loss_cls: {:.4f}, loss: {:.4f}, source_train_acc: {:>6.2f}%, target_train_acc: {:>6.2f}%, target_val_acc: {:>6.2f}%".format(\
                epoch+1, args.max_epoch, classification_loss, loss, acc_source_train, acc_target_train, val_acc))
        meters["acc_source_train"].append(acc_source_train)
        meters["acc_target_train"].append(acc_target_train)
        meters["acc_target_val"].append(val_acc)

    logging.info("Best accuracy: {:.4f}".format(best_acc))
    utils.save_log(meters, "./logs/DDC_{}_{}_meters.pkl".format(args.backbone, args.max_epoch))

    logging.info("="*15+"Done!"+"="*15)

if __name__ == "__main__":

    args = parse_args()

    # set the logger
    if not os.path.exists("./logs"):
        os.makedirs("./logs")
    setlogger(args.log_file)

    # save the args
    for k, v in args.__dict__.items():
        logging.info("{}: {}".format(k, v))
    
    trainer(args)

================================================
FILE: OSDABP.py
================================================
import argparse
import os
import numpy as np
from Utils.logger import setlogger

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Function

from Backbone import ResNet1D, MLPNet, CNN1D
from PreparData.CWRU import CWRUloader
import Utils.utils as utils

from tqdm import tqdm
import warnings
import logging

# ===== Define argments =====
def parse_args():
    parser = argparse.ArgumentParser(description='Implementation of Open Set Domain Adaptation by Backpropagation')

    # task setting
    parser.add_argument("--log_file", type=str, default="./logs/OSDABP.log", help="log file path")

    # dataset information
    parser.add_argument("--datadir", type=str, default="./datasets", help="data directory")
    parser.add_argument("--source_dataname", type=str, default="CWRU", choices=["CWRU", "PU"], help="choice a dataset")
    parser.add_argument("--target_dataname", type=str, default="CWRU", choices=["CWRU", "PU"], help="choice a dataset")
    parser.add_argument("--s_load", type=int, default=3, help="source domain working condition")
    parser.add_argument("--t_load", type=int, default=2, help="target domain working condition")
    parser.add_argument("--s_label_set", type=list, default=[0,1,2,3,4,5], help="source domain label set")
    parser.add_argument("--t_label_set", type=list, default=[0,1,2,3,4,5,6,7,8,9], help="target domain label set")
    parser.add_argument("--val_rat", type=float, default=0.3, help="training-validation rate")
    parser.add_argument("--test_rat", type=float, default=0.5, help="validation-test rate")
    parser.add_argument("--seed", type=int, default="29")

    # pre-processing
    parser.add_argument("--fft", type=bool, default=False, help="FFT preprocessing")
    parser.add_argument("--window", type=int, default=128, help="time window, if not augment data, window=1024")
    parser.add_argument("--normalization", type=str, default="0-1", choices=["None", "0-1", "mean-std"], help="normalization option")
    parser.add_argument("--savemodel", type=bool, default=False, help="whether save pre-trained model in the classification task")
    parser.add_argument("--pretrained", type=bool, default=False, help="whether use pre-trained model in transfer learning tasks")

    # backbone
    parser.add_argument("--backbone", type=str, default="ResNet1D", choices=["ResNet1D", "ResNet2D", "MLPNet", "CNN1D"])
    # if   backbone in ("ResNet1D", "CNN1D"),  data shape: (batch size, 1, 1024)
    # elif backbone == "ResNet2D",             data shape: (batch size, 3, 32, 32)
    # elif backbone == "MLPNet",               data shape: (batch size, 1024)


    # optimization & training
    parser.add_argument("--num_workers", type=int, default=0, help="the number of dataloader workers")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--max_epoch", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument('--lr_scheduler', type=str, default='stepLR', choices=['step', 'exp', 'stepLR', 'fix'], help='the learning rate schedule')
    parser.add_argument('--gamma', type=float, default=0.8, help='learning rate scheduler parameter for step and exp')
    parser.add_argument('--steps', type=str, default='30, 120', help='the learning rate decay for step and stepLR')
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "sgd"])


    args = parser.parse_args()
    return args

# ===== Build Model =====
class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg()
        return output  # forward() takes a single tensor, so backward returns a single gradient

def grad_reverse(x):
    return GradReverse.apply(x)

# define the model
class FeatureNet(nn.Module):
    def __init__(self, args):
        super(FeatureNet, self).__init__()
        if args.backbone == "ResNet1D":
            self.feature_net = ResNet1D.resnet18()
        elif args.backbone == "ResNet2D":
            self.model_ft = models.resnet18(pretrained=True)
            self.bottleneck = nn.Sequential(nn.Linear(self.model_ft.fc.out_features, 512), nn.ReLU(), nn.Dropout(0.2))
            self.feature_net = nn.Sequential(self.model_ft, self.bottleneck)
        elif args.backbone == "MLPNet":
            self.feature_net = MLPNet.MLPNet()
        elif args.backbone == "CNN1D":
            self.feature_net = CNN1D.CNN1D()
        else:
            raise Exception("model not implement")

    def forward(self, x):
        logits = self.feature_net(x)

        return logits

class Classifier(nn.Module):
    def __init__(self, args, num_out=10):
        super(Classifier, self).__init__()
        if args.backbone in ("ResNet1D", "ResNet2D"):
            self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, num_out))
        if args.backbone in ("MLPNet", "CNN1D"):
            self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(64, num_out))

    def forward(self, logits, reverse = False):
        if reverse:
            logits = grad_reverse(logits)
        outputs = self.classifier(logits)

        return outputs

# ===== Load Data =====
def loaddata(args):
    if args.source_dataname == "CWRU":
        source_data, source_label = CWRUloader(args, args.s_load, args.s_label_set)

    source_data, source_label = np.concatenate(source_data, axis=0), np.concatenate(source_label, axis=0)
    
    if args.target_dataname == "CWRU":
        target_data, target_label = CWRUloader(args, args.t_load, args.t_label_set)

    target_data, target_label = np.concatenate(target_data, axis=0), np.concatenate(target_label, axis=0)

    source_loader, _, _ = utils.DataSplite(args, source_data, source_label)
    target_trainloader, target_valloader, target_testloader = utils.DataSplite(args, target_data, target_label)
    
    return source_loader, target_trainloader, target_valloader, target_testloader

# ===== Define Loss Function =====
def bce_loss(output, target):
    output_neg = 1 - output
    target_neg = 1 - target
    result = torch.mean(target * torch.log(output + 1e-6))
    result += torch.mean(target_neg * torch.log(output_neg + 1e-6))
    return -torch.mean(result)
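
# bce_loss is used with the fixed soft target (0.5, 0.5) built below (target_funk):
# the classifier is trained to hold target samples on the known/unknown decision
# boundary, while reversed gradients train the feature net to push them off it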

# ===== Test the Model =====
def tester(featurenet, classifier, dataloader):
    featurenet.eval()
    classifier.eval()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    correct_num, total_num = 0, 0
    for i, (x_batch, y_batch) in enumerate(dataloader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # compute model output
        logits_batch = featurenet(x_batch)
        output_batch = classifier(logits_batch)

        pre = torch.max(output_batch.cpu(), 1)[1].numpy()
        y = y_batch.cpu().numpy()
        correct_num += (pre == y).sum()
        total_num += len(y)
    accuracy = (correct_num / total_num) * 100.0
    return accuracy


# ===== Train the Model =====
def trainer(args):
    # Consider the gpu or cpu condition
    if torch.cuda.is_available():
        device = torch.device("cuda")
        device_count = torch.cuda.device_count()
        logging.info('using {} gpus'.format(device_count))
        assert args.batch_size % device_count == 0, "batch size should be divisible by device count"
    else:
        warnings.warn("gpu is not available")
        device = torch.device("cpu")
        device_count = 1
        logging.info('using {} cpu'.format(device_count))
    
    # load the dataset
    source_trainloader, target_trainloader, target_valloader, target_testloader = loaddata(args)

    num_out = len(args.s_label_set)+1

    # load the model
    featurenet = FeatureNet(args)
    classifier = Classifier(args, num_out=num_out)

    # load the checkpoint
    if args.pretrained:
        if args.backbone != "ResNet2D": # pretrained ResNet2D model is downloaded from torchvision module
            if not args.fft:
                path = "./checkpoints/{}_checkpoint.tar".format(args.backbone)
            else:
                path = "./checkpoints/{}FFT_checkpoint.tar".format(args.backbone)
            featurenet.load_state_dict(torch.load(path))

    parameter_list = [{"params": featurenet.parameters(), "lr": args.lr},
                       {"params": classifier.parameters(), "lr": args.lr}]

    # Define optimizer and learning rate decay
    optimizer, lr_scheduler = utils.optimizer(args, parameter_list)

    ## define loss function
    criterion = nn.CrossEntropyLoss()

    featurenet.to(device)
    classifier.to(device)

    # train
    best_acc = 0.0
    meters = {"acc_source_train":[], "acc_target_train": [], "acc_target_val": []}

    for epoch in range(args.max_epoch):
        featurenet.train()
        classifier.train()
        with tqdm(total=len(target_trainloader), leave=False) as pbar:
            for i, ((x_s_batch, y_s_batch), (x_t_batch, y_t_batch)) in enumerate(zip(source_trainloader,target_trainloader)):

                if len(y_s_batch) != len(y_t_batch):
                    break
                batch_num = x_s_batch.size(0)

                target_funk = torch.FloatTensor(batch_num, 2).fill_(0.5).to(device)

                # clear previous gradients, compute gradients
                optimizer.zero_grad()

                # move to GPU if available
                x_s = x_s_batch.to(device)
                x_t = x_t_batch.to(device)
                y_s = y_s_batch.to(device)
                y_t = y_t_batch.to(device)

                # compute model output and loss
                # source data
                logits_s = featurenet(x_s)
                outputs_s = classifier(logits_s)
                loss_s = criterion(outputs_s, y_s.long())
                loss_s.backward()

                # target data
                logits_t = featurenet(x_t)
                outputs_t = classifier(logits_t, reverse=True)
                outputs_t = F.softmax(outputs_t, dim=1)
                prob1 = torch.sum(outputs_t[:, :num_out-1], 1).view(-1, 1)
                prob2 = outputs_t[:, num_out-1].contiguous().view(-1, 1)
                prob = torch.cat([prob1, prob2], 1)

                loss_t = bce_loss(prob, target_funk)
                loss_t.backward()
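
                # grad_reverse (reverse=True above) flips gradients at the features: the
                # classifier holds the 0.5 known/unknown boundary while the feature net
                # is trained adversarially to move target samples away from it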

                # performs updates using calculated gradients
                optimizer.step()
                # clear previous gradients
                optimizer.zero_grad()            

                pbar.update()
        
            # update lr
            if lr_scheduler is not None:
                lr_scheduler.step()

            # validation
            featurenet.eval()
            classifier.eval()
            correct_num = 0
            val_num = 0
            per_class_num = np.zeros((num_out))
            per_class_correct = np.zeros((num_out)).astype(np.float32)
            for step, (x_val_batch, y_val_batch) in enumerate(target_valloader):
                # move to GPU if available
                x_val= x_val_batch.to(device)
                y_val = y_val_batch.to(device)
                batch_size_val = y_val.data.size()[0]

                logits_val = featurenet(x_val)
                outputs_val = classifier(logits_val)

                pre = torch.max(outputs_val.cpu(), 1)[1].numpy()
                y_val = y_val.cpu().numpy()

                correct_num += (pre == y_val).sum() # the number of correct predictions per batch
                val_num += batch_size_val       # the number of predictions per batch

                for i in range(num_out):
                    if i < num_out -1:
                        index = np.where(y_val == i) # known classes
                    else:
                        index = np.where(y_val >= i) # unknown classes 
                                                     # Thanks to @Wang-Dongdong for reporting the bug
                    correct_ind = np.where(pre[index[0]]==i)
                    per_class_correct[i] += float(len(correct_ind[0]))
                    per_class_num[i] += float(len(index[0]))

            per_class_acc = (per_class_correct / per_class_num) * 100.0
            known_acc = (per_class_correct[:-1].sum() / per_class_num[:-1].sum()) * 100.0
            all_acc = (correct_num / val_num) * 100.0

            if all_acc > best_acc:
                best_acc = all_acc
            logging.info("Epoch: {:>3}/{}, loss_s: {:.4f}, loss_t: {:.4f}, all_acc: {:>6.2f}%, known_acc: {:>6.2f}%".format(\
                    epoch+1, args.max_epoch, loss_s, loss_t, all_acc, known_acc))

    logging.info("Best all accuracy: {:.4f}".format(best_acc))

    logging.info("="*10+"Done!"+"="*10)

if __name__ == "__main__":

    args = parse_args()

    # set the logger
    if not os.path.exists("./logs"):
        os.makedirs("./logs")
    setlogger(args.log_file)

    # save the args
    for k, v in args.__dict__.items():
        logging.info("{}: {}".format(k, v))
    
    trainer(args)

================================================
FILE: PreparData/CWRU.py
================================================
"""
@Author: Xiaohan Chen
@Email: cxh_bb@outlook.com
"""

import numpy as np
from scipy.io import loadmat
from PreparData.preprocess import transformation

# datanames in every working conditions
dataname_dict= {0:[97, 109, 122, 135, 174, 189, 201, 213, 226, 238],  # 1797rpm
                1:[98, 110, 123, 136, 175, 190, 202, 214, 227, 239],  # 1772rpm
                2:[99, 111, 124, 137, 176, 191, 203, 215, 228, 240],  # 1750rpm
                3:[100,112, 125, 138, 177, 192, 204, 217, 229, 241]}  # 1730rpm

axis = "_DE_time"
data_length = 1024


def CWRU(datadir, load, labels, window, normalization, backbone, fft):
    """
    loading the hole dataset
    """
    path = datadir + "/CWRU/" + "Drive_end_" + str(load) + "/"
    dataset = {label: [] for label in labels}
    for label in labels:
        fault_type = dataname_dict[load][label]
        if fault_type < 100:
            realaxis = "X0" + str(fault_type) + axis
        else:
            realaxis = "X" + str(fault_type) + axis
        mat_data = loadmat(path+str(fault_type)+".mat")[realaxis]
        start, end = 0, data_length

        # set the endpoint of data sequence
        endpoint = mat_data.shape[0]

        # sliding-window split: each segment has length data_length (1024); `window`
        # is the stride, so window < data_length yields overlapping segments (augmentation)
        while end < endpoint:
            sub_data = mat_data[start : end].reshape(-1,)

            sub_data = transformation(sub_data, fft, normalization, backbone)

            dataset[label].append(sub_data)
            start += window
            end += window
        
        dataset[label] = np.array(dataset[label], dtype="float32")

    return dataset

def CWRUloader(args, load, label_set, number="all"):
    """
    args: arguments
    number: the numbers of training samples, "all" or specific numbers (string type)
    """
    dataset = CWRU(args.datadir, load, label_set, args.window, args.normalization, args.backbone, args.fft)

    DATA, LABEL = [], []

    if number == "all":
        counter = []
        for key in dataset.keys():
            counter.append(dataset[key].shape[0])
        datan = min(counter) # choosing the min value as the sample size per class
        for key in dataset.keys():
            LABEL.append(np.tile(key, datan))
            DATA.append(dataset[key][:datan])
    else:
        datan = int(number)
        for key in dataset.keys():
            LABEL.append(np.tile(key, datan))
            DATA.append(dataset[key][:datan])
    
    DATA, LABEL = np.array(DATA, dtype="float32"), np.array(LABEL, dtype="int32")

    return DATA, LABEL

================================================
FILE: PreparData/__init__.py
================================================


================================================
FILE: PreparData/preprocess.py
================================================
import numpy as np

def transformation(sub_data, fft, normalization, backbone):

    if fft:
        sub_data = np.fft.fft(sub_data)
        sub_data = np.abs(sub_data) / len(sub_data)
        sub_data = sub_data[:int(sub_data.shape[0] / 2)].reshape(-1,)                

    if normalization == "0-1":
        sub_data = (sub_data - sub_data.min()) / (sub_data.max() - sub_data.min())
    elif normalization == "mean-std":
        sub_data = (sub_data - sub_data.mean()) / sub_data.std()

    if backbone in ("ResNet1D", "CNN1D"):
        sub_data = sub_data[np.newaxis, :]
    elif backbone == "ResNet2D":
        n = int(np.sqrt(sub_data.shape[0]))
        if fft:
            sub_data = sub_data[:n*n]
        sub_data = np.reshape(sub_data, (n, n))
        sub_data = sub_data[np.newaxis, :]
        sub_data = np.concatenate((sub_data, sub_data, sub_data), axis=0)

    return sub_data
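
if __name__ == "__main__":
    # quick self-check (a sketch, not part of the original module): a 1024-point
    # window becomes (1, 1024) for the 1D backbones and (3, 32, 32) for ResNet2D
    x = np.random.randn(1024)
    print(transformation(x, fft=False, normalization="0-1", backbone="CNN1D").shape)    # (1, 1024)
    print(transformation(x, fft=False, normalization="0-1", backbone="ResNet2D").shape) # (3, 32, 32)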

================================================
FILE: README.md
================================================
# Deep transfer learning for fault diagnosis

## :book: 1. Introduction
This repository contains popular deep transfer learning algorithms implemented via PyTorch for cross-load fault diagnosis transfer tasks, including:  

- [x] General supervised learning classification task: training and testing use the same machines, working conditions and faults.

- [x] *Domain adaptation*: the distribution of the source domain data may differ from that of the target domain data, but the target label set is identical to the source label set, i.e., $\mathcal{D} _{s}=(X_s,Y_s)$, $\mathcal{D} _{t}=(X_t,Y_t)$, $X_s \ne X_t$, $Y_s = Y_t$.
  - [x] **DDC**: Deep Domain Confusion [[arXiv 2014]](https://arxiv.org/pdf/1412.3474.pdf)
  - [x] **Deep CORAL**: Correlation Alignment for Deep Domain Adaptation [[ECCV 2016]](https://arxiv.org/abs/1607.01719)
  - [x] **DANN**: Unsupervised Domain Adaptation by Backpropagation [[ICML 2015]](http://proceedings.mlr.press/v37/ganin15.pdf)
  - [ ] TODO

- [x] *Open-set domain adaptation*: the distribution of the source domain data may differ from that of the target domain data. Moreover, the target label set contains unknown categories, i.e., $\mathcal{D} _{s}=(X_s,Y_s)$, $\mathcal{D} _{t}=(X_t,Y_t)$, $X_s \ne X_t$, $Y_s \subset Y_t$. We refer to the common categories $\mathcal{Y}_s\cap \mathcal{Y}_t$ as the *known classes*, and the target-only categories $\mathcal{Y}_t\setminus \mathcal{Y}_s$ as the *unknown class*. For the defaults in `OSDABP.py`, $\mathcal{Y}_s=\{0,\dots,5\}$ and $\mathcal{Y}_t=\{0,\dots,9\}$, so labels 6-9 are folded into a single unknown class.
  - [x] **OSDABP**: Open Set Domain Adaptation by Backpropagation [[ECCV 2018]](http://openaccess.thecvf.com/content_ECCV_2018/papers/Kuniaki_Saito_Adversarial_Open_Set_ECCV_2018_paper.pdf)
  - [ ] TODO

> For **few-shot** learning-based bearing fault diagnosis methods, please see: https://github.com/Xiaohan-Chen/few-shot-fault-diagnosis

## :balloon: 2. Citation

For a broader introduction to transfer learning in bearing fault diagnosis, please read our [paper](https://ieeexplore.ieee.org/document/10042467). If you find this repository useful in your work, please cite our paper, thank you~:
```
@ARTICLE{10042467,
  author={Chen, Xiaohan and Yang, Rui and Xue, Yihao and Huang, Mengjie and Ferrero, Roberto and Wang, Zidong},
  journal={IEEE Transactions on Instrumentation and Measurement}, 
  title={Deep Transfer Learning for Bearing Fault Diagnosis: A Systematic Review Since 2016}, 
  year={2023},
  volume={72},
  number={},
  pages={1-21},
  doi={10.1109/TIM.2023.3244237}}
```

---
## :wrench: 3. Requirements
- python 3.9.12
- Numpy 1.23.1
- pytorch 1.12.0
- scikit-learn 1.1.1
- torchvision 0.13.0

---
## :handbag: 4. Dataset
Download the bearing dataset from [CWRU Bearing Dataset Center](https://engineering.case.edu/bearingdatacenter/48k-drive-end-bearing-fault-data) and place the `.mat` files in the `./datasets` folder according to the following structure:
```
datasets/
  └── CWRU/
      ├── Drive_end_0/
      │   └── 97.mat 109.mat 122.mat 135.mat 174.mat 189.mat 201.mat 213.mat 226.mat 238.mat
      ├── Drive_end_1/
      │   └── 98.mat 110.mat 123.mat 136.mat 175.mat 190.mat 202.mat 214.mat 227.mat 239.mat
      ├── Drive_end_2/
      │   └── 99.mat 111.mat 124.mat 137.mat 176.mat 191.mat 203.mat 215.mat 228.mat 240.mat
      └── Drive_end_3/
          └── 100.mat 112.mat 125.mat 138.mat 177.mat 192.mat 204.mat 217.mat 229.mat 241.mat
```

---
## :pencil: 5. Usage
> **NOTE**: When using pre-trained models to initialise the backbone and classifier in transfer learning tasks, run classification tasks first to generate corresponding checkpoints.
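
For example (a sketch, assuming `classification.py` exposes the same `--backbone` and `--savemodel` flags as the transfer scripts; the transfer scripts load checkpoints from `./checkpoints/{backbone}_checkpoint.tar`):
```bash
python3 classification.py --backbone "CNN1D" --savemodel True   # produces the CNN1D checkpoint
python3 DANN.py --backbone "CNN1D" --pretrained True            # initialises the feature net from it
```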

Four typical neural networks are implemented in this repository: MLP, 1D CNN, 1D ResNet18, and 2D ResNet18 (torchvision package). More details can be found in the `./Backbone` folder.
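
As a quick sanity check (a sketch; input shapes follow the comments in the training scripts), note that the backbones return feature vectors, and each task script attaches its own classifier head:
```python
import torch
from Backbone import CNN1D, MLPNet, ResNet1D

x_1d = torch.randn(8, 1, 1024)   # ResNet1D / CNN1D input shape
x_mlp = torch.randn(8, 1024)     # MLPNet input shape

print(ResNet1D.resnet18()(x_1d).shape)  # torch.Size([8, 512])
print(CNN1D.CNN1D()(x_1d).shape)        # torch.Size([8, 64])
print(MLPNet.MLPNet()(x_mlp).shape)     # torch.Size([8, 64])
```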

**General Supervised Learning Classification:**
- Train and test the model on the same machines, working conditions and faults. Use the following commands:
```bash
python3 classification.py --datadir './datasets' --max_epoch 100
```

**Transfer Learning:**
- If using the DDC transfer learning method, use the following commands:
```bash
python3 DDC.py --datadir './datasets' --backbone "CNN1D" --pretrained False --kernel 'Linear'
```
- If using the DeepCORAL transfer learning method, use the following commands:
```bash
python3 DDC.py --datadir './datasets' --backbone "CNN1D" --pretrained False --kernel 'CORAL'
```
- If using the DANN transfer learning method, use the following commands:
```bash
python3 DANN.py --backbone "CNN1D"
```
**Open Set Domain Adaptation:**
- The target domain contains unknown classes; use the following commands:
```bash
python3 OSDABP.py
```
---
## :flashlight: 6. Results
> The following results do not represent the best results.

**General Classification task:**  
Dataset: CWRU  
Load: 3  
Label set: [0,1,2,3,4,5,6,7,8,9]  

|                   | MLPNet | CNN1D | ResNet1D | ResNet2D |
| :---------------: | :----: | :---: | :------: | :------: |
| acc (time domain) | 93.95  | 97.70 |  99.58   |  98.02   |
| acc (freq domain) | 99.95  | 99.44 |  100.0   |  99.96   |

**Transfer Learning:**  
Dataset: CWRU  
Source load: 3  
Target Load: 2  
Label set: [0,1,2,3,4,5,6,7,8,9]  
Pre-trained model: True  

Time domain:  
|                     | MLPNet | CNN1D | ResNet1D | ResNet2D |
| :-----------------: | :----: | :---: | :------: | :------: |
| DDC (linear kernel) | 75.47  | 85.53 |  91.79   |  91.32   |
|      DeepCORAL      | 82.33  | 88.23 |  93.88   |  90.84   |
|        DANN         | 87.68  | 94.77 |  98.88   |  93.95   |

Frequency domain:
|           | MLPNet | CNN1D | ResNet1D | ResNet2D |
| :-------: | :----: | :---: | :------: | :------: |
| DeepCORAL | 98.65  | 98.22 |  99.75   |  99.31   |
|   DANN    | 99.38  | 98.74 |  99.89   |  99.47   |

**Open Set Domain Adaptation**  
- *OSDABP*

Dataset: CWRU  
Source load: 3  
Target Load: 2  
Source label set: [0,1,2,3,4,5]  
Target label set: [0,1,2,3,4,5,6,7,8,9]  
Pre-trained model: True  

|  Label   |   0   |   1   |   2   |   3   |   4   |   5   |  unk  | All   | Only known |
| :------: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :--------: |
|  MLPNet  | 99.83 | 95.96 | 59.76 | 76.10 | 19.85 | 96.58 | 59.21 | 70.21 | 75.99      |
|  CNN1D   | 100.0 | 94.95 | 94.47 | 99.08 | 47.31 | 74.32 | 26.36 | 61.75 | 85.35      |
| ResNet1D | 100.0 | 100.0 | 80.14 | 100.0 | 43.32 | 93.49 | 45.22 | 70.04 | 86.58      |
| ResNet2D | 100.0 | 100.0 | 94.82 | 100.0 | 18.55 | 98.12 | 53.42 | 72.95 | 85.96      |


---
## :camping: 7. See also
- Multi-scale CNN and LSTM bearing fault diagnosis [[paper](https://link.springer.com/article/10.1007/s10845-020-01600-2)][[GitHub](https://github.com/Xiaohan-Chen/baer_fault_diagnosis)]
- TFPred self-supervised learning for few labeled fault diagnosis [[Paper](https://www.sciencedirect.com/science/article/pii/S0967066124000601)][[GitHub](https://github.com/Xiaohan-Chen/TFPred)]

---
## :globe_with_meridians: 8. Acknowledgement

```
@article{zhao2021applications,
  title={Applications of Unsupervised Deep Transfer Learning to Intelligent Fault Diagnosis: A Survey and Comparative Study},
  author={Zhibin Zhao and Qiyang Zhang and Xiaolei Yu and Chuang Sun and Shibin Wang and Ruqiang Yan and Xuefeng Chen},
  journal={IEEE Transactions on Instrumentation and Measurement},
  year={2021}
}
```

I would like to thank the following people for contributing to this repository: [@Wang-Dongdong](https://github.com/Wang-Dongdong), [@zhuting233](https://github.com/zhuting233).

================================================
FILE: Utils/__init__.py
================================================


================================================
FILE: Utils/logger.py
================================================
import logging

def setlogger(path):

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    consoleHandler = logging.StreamHandler()
    fileHandler = logging.FileHandler(filename=path)

    consoleHandler.setFormatter(formatter)
    fileHandler.setFormatter(formatter)

    logger.addHandler(consoleHandler)
    logger.addHandler(fileHandler)
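
if __name__ == "__main__":
    # Example usage (illustrative; not part of the original file):
    import os
    os.makedirs("./logs", exist_ok=True)
    setlogger("./logs/example.log")
    logging.info("logger initialised")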

================================================
FILE: Utils/utils.py
================================================
import logging
import torch
import pickle
import torch.optim as optim
from sklearn.model_selection import train_test_split

def accuracy(outputs, labels):
    """
    Compute the accuracy
    outputs, labels: (tensor)
    return: (float) accuracy in [0, 100]
    """
    pre = torch.max(outputs.cpu(), 1)[1].numpy()
    y = labels.data.cpu().numpy()
    acc = ((pre == y).sum() / len(y)) * 100
    return acc
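# Example (illustrative): accuracy(torch.tensor([[0.1, 0.9]]), torch.tensor([1])) -> 100.0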

def save_log(obj, path):
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def read_pkl(path):
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return data

def save_model(model, args):
    if not args.fft:
        torch.save(model.state_dict(), "./checkpoints/{}_checkpoint.tar".format(args.backbone))
    else:
        torch.save(model.state_dict(), "./checkpoints/{}FFT_checkpoint.tar".format(args.backbone))

def DataSplite(args, data, label):
    """
    split the data and lebel and transform the narray type to tensor type
    """
    data_train, data_val, label_train, label_val = train_test_split(data, label, test_size = args.val_rat, random_state=args.seed)
    data_val, data_test, label_val, label_test = train_test_split(data_val, label_val, test_size= args.test_rat)

    # numpy to tensor
    data_train = torch.from_numpy(data_train).float()
    data_val = torch.from_numpy(data_val).float()
    data_test = torch.from_numpy(data_test).float()
    label_train = torch.from_numpy(label_train).float()
    label_val = torch.from_numpy(label_val).float()
    label_test = torch.from_numpy(label_test).float()

    # logging the data shape
    logging.info("training data/label shape: {},{}".format(data_train.size(), label_train.size()))
    logging.info("validation data/label shape: {},{}".format(data_val.size(), label_val.size()))
    logging.info("test data/label shape: {},{}".format(data_test.size(), label_test.size()))

    # build the dataloader
    train = torch.utils.data.TensorDataset(data_train, label_train)
    val = torch.utils.data.TensorDataset(data_val, label_val)
    test = torch.utils.data.TensorDataset(data_test, label_test)

    train_loader = torch.utils.data.DataLoader(train, batch_size=args.batch_size, \
        shuffle=True, num_workers=args.num_workers)

    val_loader = torch.utils.data.DataLoader(val, batch_size=args.batch_size, \
        shuffle=False, num_workers=args.num_workers)
    test_loader = torch.utils.data.DataLoader(test, batch_size=args.batch_size, \
        shuffle=False, num_workers=args.num_workers)

    return train_loader, val_loader, test_loader

def optimizer(args, parameter_list):
    # define optimizer
    if args.optimizer == "sgd":
        optimizer = optim.SGD(parameter_list, lr=args.lr, momentum=0.9, weight_decay=5e-4)
    elif args.optimizer == "adam":
        optimizer = optim.Adam(parameter_list, lr=args.lr)
    else:
        raise Exception("optimizer not implement")

    # Define the learning rate decay
    if args.lr_scheduler == 'step':
        steps = [int(step) for step in args.steps.split(',')]
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, steps, gamma=args.gamma)
    elif args.lr_scheduler == 'exp':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.gamma)
    elif args.lr_scheduler == 'stepLR':
        steps = int(args.steps.split(",")[0])
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, steps, args.gamma)
    elif args.lr_scheduler == 'cos':
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 20, 0)
    elif args.lr_scheduler == 'fix':
        lr_scheduler = None
    else:
        raise Exception("lr schedule not implement")

    return optimizer, lr_scheduler
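
if __name__ == "__main__":
    # Example usage (illustrative; not part of the original file): build an
    # Adam optimizer with a StepLR schedule from a minimal args namespace.
    import argparse
    args = argparse.Namespace(optimizer="adam", lr=1e-3,
                              lr_scheduler="stepLR", steps="30, 120", gamma=0.8)
    model = torch.nn.Linear(10, 2)
    opt, scheduler = optimizer(args, model.parameters())
    print(type(opt).__name__, type(scheduler).__name__)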


================================================
FILE: classification.py
================================================
import argparse
import os
import numpy as np
from Utils.logger import setlogger

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from Backbone import ResNet1D, MLPNet, CNN1D
from PreparData.CWRU import CWRUloader
import Utils.utils as utils

from tqdm import tqdm
import warnings
import logging

# ===== Define argments =====
def parse_args():
    parser = argparse.ArgumentParser(description='classification task')

    # task setting
    parser.add_argument("--log_file", type=str, default="./logs/classification.log", help="log file path")

    # dataset information
    parser.add_argument("--datadir", type=str, default="./datasets", help="data directory")
    parser.add_argument("--load", type=int, default=3, help="working condition")
    parser.add_argument("--label_set", type=list, default=[0,1,2,3,4,5,6,7,8,9], help="label set")
    parser.add_argument("--val_rat", type=float, default=0.3, help="training-validation rate")
    parser.add_argument("--test_rat", type=float, default=0.5, help="validation-test rate")
    parser.add_argument("--seed", type=int, default="29")

    # pre-processing
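    # NOTE: these options use type=bool, so argparse converts any non-empty
    # string (including "False") to True; omit a flag to keep its default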
    parser.add_argument("--fft", type=bool, default=False, help="FFT preprocessing")
    parser.add_argument("--window", type=int, default=128, help="time window, if not augment data, window=1024")
    parser.add_argument("--normalization", type=str, default="0-1", choices=["None", "0-1", "mean-std"], help="normalization option")
    parser.add_argument("--savemodel", type=bool, default=False, help="whether save pre-trained model in the classification task")
    parser.add_argument("--pretrained", type=bool, default=False, help="whether use pre-trained model in transfer learning tasks")

    # backbone
    parser.add_argument("--backbone", type=str, default="ResNet1D", choices=["ResNet1D", "ResNet2D", "MLPNet", "CNN1D"])
    # if   backbone in ("ResNet1D", "CNN1D"),  data shape: (batch size, 1, 1024)
    # elif backbone == "ResNet2D",             data shape: (batch size, 3, 32, 32)
    # elif backbone == "MLPNet",               data shape: (batch size, 1024)


    # optimization & training
    parser.add_argument("--num_workers", type=int, default=0, help="the number of dataloader workers")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--max_epoch", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument('--lr_scheduler', type=str, default='stepLR', choices=['step', 'exp', 'stepLR', 'fix'], help='the learning rate schedule')
    parser.add_argument('--gamma', type=float, default=0.8, help='learning rate scheduler parameter for step and exp')
    parser.add_argument('--steps', type=str, default='30, 120', help='the learning rate decay for step and stepLR')
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "sgd"])


    args = parser.parse_args()
    return args

# ===== Build Model =====
class FeatureNet(nn.Module):
    def __init__(self, args):
        super(FeatureNet, self).__init__()
        if args.backbone == "ResNet1D":
            self.feature_net = ResNet1D.resnet18()
        elif args.backbone == "ResNet2D":
            self.model_ft = models.resnet18(pretrained=True)
            self.bottleneck = nn.Sequential(nn.Linear(self.model_ft.fc.out_features, 512), nn.ReLU(), nn.Dropout(0.5))
            self.feature_net = nn.Sequential(self.model_ft, self.bottleneck)
        elif args.backbone == "MLPNet":
            if args.fft:
                self.feature_net = MLPNet.MLPNet(num_in=512)
            else:
                self.feature_net = MLPNet.MLPNet()
        elif args.backbone == "CNN1D":
            self.feature_net = CNN1D.CNN1D()
        else:
            raise Exception("model not implement")

    def forward(self, x):
        logits = self.feature_net(x)

        return logits

class Classifier(nn.Module):
    def __init__(self, args, num_out=10):
        super(Classifier, self).__init__()
        if args.backbone in ("ResNet1D", "ResNet2D"):
            self.classifier = nn.Sequential(nn.Linear(512,num_out, nn.Dropout(0.5)))
        if args.backbone in ("MLPNet", "CNN1D"):
            self.classifier = nn.Sequential(nn.Linear(64,num_out, nn.Dropout(0.5)))

    def forward(self, logits):
        outputs = self.classifier(logits)

        return outputs

# ===== Load Data =====
def loaddata(args):
    data, label = CWRUloader(args, args.load, args.label_set)
    data, label = np.concatenate(data, axis=0), np.concatenate(label, axis=0)
    
    train_loader, val_loader, test_loader = utils.DataSplite(args, data, label)

    return train_loader, val_loader, test_loader

# ===== Test the Model =====
def tester(featurenet, classifier, dataloader):
    featurenet.eval()
    classifier.eval()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    correct_num, total_num = 0, 0
    for i, (x_batch, y_batch) in enumerate(dataloader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # compute model output
        logits_batch = featurenet(x_batch)
        output_batch = classifier(logits_batch)

        pre = torch.max(output_batch.cpu(), 1)[1].numpy()
        y = y_batch.cpu().numpy()
        correct_num += (pre == y).sum()
        total_num += len(y)
    accuracy = (correct_num / total_num) * 100.0
    return accuracy


# ===== Train the Model =====
def trainer(args):
    # Consider the gpu or cpu condition
    if torch.cuda.is_available():
        device = torch.device("cuda")
        device_count = torch.cuda.device_count()
        logging.info('using {} gpus'.format(device_count))
        assert args.batch_size % device_count == 0, "batch size should be divisible by device count"
    else:
        warnings.warn("gpu is not available")
        device = torch.device("cpu")
        device_count = 1
        logging.info('using {} cpu'.format(device_count))
    
    # load the dataset
    trainloader, valloader, testloader = loaddata(args)

    # load the model
    featurenet = FeatureNet(args)
    classifier = Classifier(args, num_out=len(args.label_set))

    parameter_list = [{"params": featurenet.parameters(), "lr": args.lr},
                       {"params": classifier.parameters(), "lr": args.lr}]

    # Define optimizer and learning rate decay
    optimizer, lr_scheduler = utils.optimizer(args, parameter_list)

    # define loss function
    loss_fn = nn.CrossEntropyLoss()
    featurenet.to(device)
    classifier.to(device)
    # train
    best_acc = 0.0
    meters = {"acc_train": [], "acc_val": []}

    for epoch in range(args.max_epoch):
        featurenet.train()
        classifier.train()
        with tqdm(total=len(trainloader), leave=False) as pbar:
            for i, (x_batch, y_batch) in enumerate(trainloader):
                # move to GPU if available
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)

                # compute model output and loss
                logits_batch = featurenet(x_batch)
                output_batch = classifier(logits_batch)
                loss = loss_fn(output_batch, y_batch.long())

                # clear previous gradients, compute gradients
                optimizer.zero_grad()
                loss.backward()

                # performs updates using calculated gradients
                optimizer.step()

                # evaluate
                # training accuracy
                train_acc = utils.accuracy(output_batch, y_batch)

                pbar.update()
        
        # update lr
        if lr_scheduler is not None:
            lr_scheduler.step()

        val_acc = tester(featurenet, classifier, valloader)
        if val_acc > best_acc:
            best_acc = val_acc
            if args.savemodel:
                utils.save_model(featurenet, args)
        
        logging.info("Epoch: {:>3}/{}, loss: {:.4f}, train_acc: {:>6.2f}%, val_acc: {:>6.2f}%".format(\
                epoch+1, args.max_epoch, loss, train_acc, val_acc))
        meters["acc_train"].append(train_acc)
        meters["acc_val"].append(val_acc)

    logging.info("Best accuracy: {:.4f}%".format(best_acc))
    utils.save_log(meters, "./logs/cls_{}_{}_meters.pkl".format(args.backbone, args.max_epoch))

    logging.info("="*15+"Done!"+"="*15)

if __name__ == "__main__":

    args = parse_args()

    # set the logger
    if not os.path.exists("./logs"):
        os.makedirs("./logs")
    setlogger(args.log_file)

    # save the pre-trained model
    if not os.path.exists("./checkpoints"):
        os.makedirs("./checkpoints")

    # save the args
    for k, v in args.__dict__.items():
        logging.info("{}: {}".format(k, v))
    
    trainer(args)

================================================
FILE: loss/CORAL.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch


"""
Created on Saturday Feb 25 2020

@authors: Alan Preciado, Santosh Muthireddy
"""
def CORAL_loss(source, target):
	"""
	From the paper, the vectors that compose Ds and Dt are D-dimensional vectors
	:param source: torch tensor: source data (Ds) with dimensions DxNs
	:param target: torch tensor: target data (Dt) with dimensons DxNt
	"""

	d = source.size(1) # d-dimensional vectors (same for source, target)

	source_covariance = compute_covariance(source)
	target_covariance = compute_covariance(target)

	# squared Frobenius norm ||Cs - Ct||_F^2, as in the paper (eq. 1);
	# note: the original applied torch.norm to the element-wise square,
	# which computes a different quantity than the squared Frobenius norm
	diff = source_covariance - target_covariance
	loss = torch.sum(torch.mul(diff, diff))

	loss = loss/(4*d*d)

	return loss


def compute_covariance(data):
	"""
	Compute covariance matrix for given dataset as shown in paper (eqs 2 and 3).
	:param data: torch tensor: input source/target data
	"""

	# data dimensions: nxd (this for Ns or Nt)
	n = data.size(0) # get batch size
	#print("compute covariance bath size n:", n)

	# check gpu or cpu support
	if data.is_cuda:
		device = torch.device("cuda")
	else:
		device = torch.device("cpu")

	# proper matrix multiplication for right side of equation (2)
	ones_vector = torch.ones((1, n), device=device) # 1xN row vector of ones (Tensor.resize is deprecated)
	one_onto_D = torch.mm(ones_vector, data)
	mult_right_terms = torch.mm(one_onto_D.t(), one_onto_D)
	mult_right_terms = torch.div(mult_right_terms, n) # element-wise division

	# matrix multiplication for left side of equation (2)
	mult_left_terms = torch.mm(data.t(), data)

	covariance_matrix= 1/(n-1) * torch.add(mult_left_terms,-1*(mult_right_terms))

	return covariance_matrix
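
if __name__ == "__main__":
	# Quick self-check (illustrative; not part of the original file):
	# the loss between a batch and itself is exactly zero.
	src = torch.randn(32, 64)
	tgt = torch.randn(32, 64)
	print(CORAL_loss(src, src).item())  # 0.0
	print(CORAL_loss(src, tgt).item())  # > 0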


================================================
FILE: loss/MKMMD.py
================================================
# source code: https://github.com/ZhaoZhibin/UDTL (it seems the link is not the original source of the MK-MMD code)
# Params:
#       source: source domain features
#       target: target domain features
#       kernel_mul: multiplicative step between adjacent kernel bandwidths
#       kernel_num: the number of Gaussian kernels
#       fix_sigma: if given, use this fixed bandwidth instead of the data-driven estimate

import torch

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    n_samples = int(source.size()[0])+int(target.size()[0])  # the number of source+target
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0-total1)**2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)#/len(kernel_val)


def MKMMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    XX = kernels[:batch_size, :batch_size]
    YY = kernels[batch_size:, batch_size:]
    XY = kernels[:batch_size, batch_size:]
    YX = kernels[batch_size:, :batch_size]
    loss = torch.mean(XX + YY - XY - YX)
    return loss
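
if __name__ == "__main__":
    # Quick self-check (illustrative; not part of the original file):
    # the estimate should grow when the target distribution is shifted.
    src = torch.randn(64, 128)
    tgt = torch.randn(64, 128)
    print(MKMMD(src, tgt).item())
    print(MKMMD(src, tgt + 2.0).item())  # larger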


================================================
FILE: loss/MMDLinear.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch


"""
Created on Sunday 22 Mar 2020

@authors: Alan Preciado, Santosh Muthireddy
"""
def MMDLinear(source_activation, target_activation):
	"""
	From the paper, the loss used is the maximum mean discrepancy (MMD)
	:param source: torch tensor: source data (Ds) with dimensions DxNs
	:param target: torch tensor: target data (Dt) with dimensons DxNt
	"""

	diff_domains = source_activation - target_activation
	loss = torch.mean(torch.mm(diff_domains, torch.transpose(diff_domains, 0, 1)))

	return loss
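
if __name__ == "__main__":
	# Quick self-check (illustrative; not part of the original file): the batches
	# are subtracted element-wise, so source and target must have equal shapes.
	src = torch.randn(64, 128)
	tgt = torch.randn(64, 128)
	print(MMDLinear(src, tgt).item())
	print(MMDLinear(src, src).item())  # 0.0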


================================================
FILE: loss/__init__.py
================================================