Repository: carrierlxk/COSNet
Branch: master
Commit: 549109db0d60
Files: 17
Total size: 131.8 KB

Directory structure:
gitextract_zvfyan4m/

├── README.md
├── dataloaders/
│   ├── PairwiseImg_test.py
│   ├── PairwiseImg_video.py
│   ├── PairwiseImg_video_test_try.py
│   ├── PairwiseImg_video_try.py
│   └── r
├── deeplab/
│   ├── __init__.py
│   ├── e
│   ├── siamese_model_conf.py
│   ├── siamese_model_conf_try_single.py
│   └── utils.py
├── densecrf_apply_cvpr2019.py
├── pretrained/
│   └── deep_labv3/
│       └── readme.md
├── test_coattention_conf.py
├── test_iteration_conf_group.py
├── train_iteration_conf.py
└── train_iteration_conf_group.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# COSNet
Code for CVPR 2019 paper: 

[See More, Know More: Unsupervised Video Object Segmentation with
Co-Attention Siamese Networks](http://openaccess.thecvf.com/content_CVPR_2019/papers/Lu_See_More_Know_More_Unsupervised_Video_Object_Segmentation_With_Co-Attention_CVPR_2019_paper.pdf)

[Xiankai Lu](https://sites.google.com/site/xiankailu111/), [Wenguan Wang](https://sites.google.com/view/wenguanwang), Chao Ma, Jianbing Shen, Ling Shao, Fatih Porikli


![](../master/framework.png)

- - -
:new:

Our group co-attention achieves a further performance gain (81.1 mean J on the DAVIS-16 dataset); the related code has also been released.

The pre-trained model, testing code and training code are provided below.

### Quick Start

#### Testing

1. Install PyTorch (version 1.0.1).

2. Download the pretrained model. In `test_coattention_conf.py`, set the DAVIS dataset path, the pretrained model path and the result path.

3. Run the command: `python test_coattention_conf.py --dataset davis --gpus 0`

4. The CRF post-processing code comes from https://github.com/lucasb-eyer/pydensecrf.

The pretrained weights can be downloaded from [GoogleDrive](https://drive.google.com/open?id=14ya3ZkneeHsegCgDrvkuFtGoAfVRgErz) or [BaiduPan](https://pan.baidu.com/s/16oFzRmn4Meuq83fCYr4boQ), pass code: xwup.

The segmentation results on DAVIS, FBMS and Youtube-objects can be downloaded from the [DAVIS benchmark](https://davischallenge.org/davis2016/soa_compare.html),
[GoogleDrive](https://drive.google.com/open?id=1JRPc2kZmzx0b7WLjxTPD-kdgFdXh5gBq) or [BaiduPan](https://pan.baidu.com/s/11n7zAt3Lo2P3-42M2lsw6Q), pass code: q37f.

The youtube-objects dataset can be downloaded from [here](http://calvin-vision.net/datasets/youtube-objects-dataset/) and annotation can be found [here](http://vision.cs.utexas.edu/projects/videoseg/data_download_register.html).

The FBMS dataset can be downloaded from [here](https://lmb.informatik.uni-freiburg.de/resources/datasets/moseg.en.html).
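The "mean J" quoted above is the standard DAVIS region-similarity score: the intersection-over-union between each predicted binary mask and its ground-truth annotation, averaged over frames. A minimal NumPy sketch (the function name is illustrative, not part of this repo):

```python
import numpy as np

def mean_jaccard(preds, gts):
    """Mean region similarity J over paired lists of binary (H, W) masks."""
    scores = []
    for pred, gt in zip(preds, gts):
        pred, gt = pred.astype(bool), gt.astype(bool)
        inter = np.logical_and(pred, gt).sum()
        union = np.logical_or(pred, gt).sum()
        # empty prediction and empty ground truth count as a perfect match
        scores.append(1.0 if union == 0 else inter / union)
    return float(np.mean(scores))
```

The official DAVIS evaluation toolkit should be used for reported numbers; this sketch only shows what the metric measures.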
#### Training

1. Download all the training datasets, including the MSRA10K and DUT saliency datasets. Create a folder called `images` and put the two datasets into it.

2. Download the deeplabv3 model from [GoogleDrive](https://drive.google.com/open?id=1hy0-BAEestT9H4a3Sv78xrHrzmZga9mj). Put it into the folder pretrained/deep_labv3.

3. Change the video path, image path and deeplabv3 path in `train_iteration_conf.py`. Create two txt files listing the saliency dataset names and the DAVIS-16 training sequence names, then update the corresponding txt paths in `PairwiseImg_video.py`.

4. Run the command: `python train_iteration_conf.py --dataset davis --gpus 0,1`
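The two txt files in step 3 are plain lists of folder names, one per line (the dataloader reads them with `readlines()` and strips the newline). A sketch that generates such a file from a dataset directory; the paths in the usage comment are placeholders to adapt:

```python
import os

def write_name_list(parent_dir, out_txt):
    """Write one sub-folder name per line, sorted (e.g. DAVIS sequence names)."""
    names = sorted(
        d for d in os.listdir(parent_dir)
        if os.path.isdir(os.path.join(parent_dir, d))
    )
    with open(out_txt, 'w') as f:
        f.write('\n'.join(names) + '\n')
    return names

# Usage (adapt the paths to your setup):
# write_name_list('/DAVIS-2016/JPEGImages/480p', '/DAVIS-2016/train_seqs.txt')
# write_name_list('/path/to/images', 'saliency_data.txt')
```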

### Citation

If you find the code and dataset useful in your research, please consider citing:
```
@InProceedings{Lu_2019_CVPR,  
author = {Lu, Xiankai and Wang, Wenguan and Ma, Chao and Shen, Jianbing and Shao, Ling and Porikli, Fatih},  
title = {See More, Know More: Unsupervised Video Object Segmentation With Co-Attention Siamese Networks},  
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},  
year = {2019}  
}
@article{lu2020_pami,
  title={Zero-Shot Video Object Segmentation with Co-Attention Siamese Networks},
  author={Lu, Xiankai and Wang, Wenguan and Shen, Jianbing and Crandall, David and Luo, Jiebo},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2020},
  publisher={IEEE}
}
```
### Other related projects/papers:
[Saliency-Aware Geodesic Video Object Segmentation (CVPR15)](https://github.com/wenguanwang/saliencysegment)

[Learning Unsupervised Video Primary Object Segmentation through Visual Attention (CVPR19)](https://github.com/wenguanwang/AGS)

For any comments, please email carrierlxk@gmail.com.


================================================
FILE: dataloaders/PairwiseImg_test.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 12 11:39:54 2018

@author: carri
"""
# for testing case
from __future__ import division

import os
import numpy as np
import cv2
from scipy.misc import imresize
import scipy.misc 
import random

#from dataloaders.helpers import *
from torch.utils.data import Dataset

def flip(I,flip_p):
    if flip_p>0.5:
        return np.fliplr(I)
    else:
        return I

def scale_im(img_temp, scale):
    # cv2.resize expects dsize as (width, height): shape[1] is width, shape[0] is height
    new_dims = (int(img_temp.shape[1]*scale), int(img_temp.shape[0]*scale))
    return cv2.resize(img_temp, new_dims).astype(float)

def scale_gt(img_temp, scale):
    new_dims = (int(img_temp.shape[1]*scale), int(img_temp.shape[0]*scale))
    return cv2.resize(img_temp, new_dims, interpolation=cv2.INTER_NEAREST).astype(float)

def my_crop(img,gt):
    H = int(0.9 * img.shape[0])
    W = int(0.9 * img.shape[1])
    H_offset = random.choice(range(img.shape[0] - H))
    W_offset = random.choice(range(img.shape[1] - W))
    H_slice = slice(H_offset, H_offset + H)
    W_slice = slice(W_offset, W_offset + W)
    img = img[H_slice, W_slice, :]
    gt = gt[H_slice, W_slice]
    
    return img, gt

class PairwiseImg(Dataset):
    """DAVIS 2016 dataset constructed using the PyTorch built-in functionalities"""

    def __init__(self, train=True,
                 inputRes=None,
                 db_root_dir='/DAVIS-2016',
                 transform=None,
                 meanval=(104.00699, 116.66877, 122.67892),
                 seq_name=None, sample_range=10):
        """Loads image to label pairs for tool pose estimation
        db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations"
        """
        self.train = train
        self.range = sample_range
        self.inputRes = inputRes
        self.db_root_dir = db_root_dir
        self.transform = transform
        self.meanval = meanval
        self.seq_name = seq_name

        if self.train:
            fname = 'train_seqs'
        else:
            fname = 'val_seqs'

        if self.seq_name is None:  # all sequences take part in training
            with open(os.path.join(db_root_dir, fname + '.txt')) as f:
                seqs = f.readlines()
                img_list = []
                labels = []
                Index = {}
                for seq in seqs:                    
                    images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\n'))))
                    images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images))
                    start_num = len(img_list)
                    img_list.extend(images_path)
                    end_num = len(img_list)
                    Index[seq.strip('\n')]= np.array([start_num, end_num])
                    lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\n'))))
                    lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab))
                    labels.extend(lab_path)
        else:  # for the training samples, img_list stores the image paths

            # Initialize the per sequence images for online training
            names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name))))
            img_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img))
            #name_label = np.sort(os.listdir(os.path.join(db_root_dir,  str(seq_name))))
            labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])]
            labels.extend([None]*(len(names_img)-1))  # pad the labels list with None entries
            if self.train:
                img_list = [img_list[0]]
                labels = [labels[0]]

        assert (len(labels) == len(img_list))

        self.img_list = img_list
        self.labels = labels
        self.Index = Index
        #img_files = open('all_im.txt','w+')

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        target, target_gt, sequence_name = self.make_img_gt_pair(idx)  # the frame to segment at test time
        target_id = idx
        seq_name1 = self.img_list[target_id].split('/')[-2]  # get the video (sequence) name
        sample = {'target': target, 'target_gt': target_gt, 'seq_name': sequence_name, 'search_0': None}
        if self.range>=1:
            my_index = self.Index[seq_name1]
            search_num = list(range(my_index[0], my_index[1]))  
            search_ids = random.sample(search_num, self.range)#min(len(self.img_list)-1, target_id+np.random.randint(1,self.range+1))
        
            for i in range(0,self.range):
                search_id = search_ids[i]
                search, search_gt,sequence_name = self.make_img_gt_pair(search_id)
                if sample['search_0'] is None:
                    sample['search_0'] = search
                else:
                    sample['search'+'_'+str(i)] = search
            #np.save('search1.npy',search)
            #np.save('search_gt.npy',search_gt)
            if self.seq_name is not None:
                fname = os.path.join(self.seq_name, "%05d" % idx)
                sample['fname'] = fname
       
        else:
            img, gt = self.make_img_gt_pair(idx)
            sample = {'image': img, 'gt': gt}
            if self.seq_name is not None:
                fname = os.path.join(self.seq_name, "%05d" % idx)
                sample['fname'] = fname

        return sample  # the final output of this class

    def make_img_gt_pair(self, idx):  # helper that serves __getitem__
        """
        Make the image-ground-truth pair
        """
        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx]), cv2.IMREAD_COLOR)
        if self.labels[idx] is not None and self.train:
            label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE)
            #print(os.path.join(self.db_root_dir, self.labels[idx]))
        else:
            gt = np.zeros(img.shape[:-1], dtype=np.uint8)
            
        ## the image and its ground truth are loaded; data augmentation can start
        if self.train:  #scaling, cropping and flipping
             img, label = my_crop(img,label)
             scale = random.uniform(0.7, 1.3)
             flip_p = random.uniform(0, 1)
             img_temp = scale_im(img,scale)
             img_temp = flip(img_temp,flip_p)
             gt_temp = scale_gt(label,scale)
             gt_temp = flip(gt_temp,flip_p)
             
             img = img_temp
             label = gt_temp
             
        if self.inputRes is not None:
            img = imresize(img, self.inputRes)
            #print('ok1')
            #scipy.misc.imsave('label.png',label)
            #scipy.misc.imsave('img.png',img)
            if self.labels[idx] is not None and self.train:
                label = imresize(label, self.inputRes, interp='nearest')

        img = np.array(img, dtype=np.float32)
        #img = img[:, :, ::-1]
        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        
        if self.labels[idx] is not None and self.train:
                gt = np.array(label, dtype=np.int32)
                gt[gt!=0]=1
                #gt = gt/np.max([gt.max(), 1e-8])
        #np.save('gt.npy')
        sequence_name = self.img_list[idx].split('/')[2]
        return img, gt, sequence_name 

    def get_img_size(self):
        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))
        
        return list(img.shape[:2])
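
# The normalization inside make_img_gt_pair above can be factored into a small
# standalone helper (a sketch for clarity; this function is not used elsewhere
# in the repo): subtract the BGR channel means, then move channels first for PyTorch.
def normalize_bgr(img, meanval=(104.00699, 116.66877, 122.67892)):
    img = np.array(img, dtype=np.float32)
    img = np.subtract(img, np.array(meanval, dtype=np.float32))
    return img.transpose((2, 0, 1))  # HWC -> CHW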


if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt

    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])

    #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',
                       # train=True, transform=transforms)
    #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
#
#    for i, data in enumerate(dataloader):
#        plt.figure()
#        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))
#        if i == 10:
#            break
#
#    plt.show(block=True)


================================================
FILE: dataloaders/PairwiseImg_video.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 12 11:39:54 2018

@author: carri
"""

from __future__ import division

import os
import numpy as np
import cv2
from scipy.misc import imresize
import scipy.misc 
import random

from dataloaders.helpers import *
from torch.utils.data import Dataset

def flip(I,flip_p):
    if flip_p>0.5:
        return np.fliplr(I)
    else:
        return I

def scale_im(img_temp, scale):
    # cv2.resize expects dsize as (width, height): shape[1] is width, shape[0] is height
    new_dims = (int(img_temp.shape[1]*scale), int(img_temp.shape[0]*scale))
    return cv2.resize(img_temp, new_dims).astype(float)

def scale_gt(img_temp, scale):
    new_dims = (int(img_temp.shape[1]*scale), int(img_temp.shape[0]*scale))
    return cv2.resize(img_temp, new_dims, interpolation=cv2.INTER_NEAREST).astype(float)

def my_crop(img,gt):
    H = int(0.9 * img.shape[0])
    W = int(0.9 * img.shape[1])
    H_offset = random.choice(range(img.shape[0] - H))
    W_offset = random.choice(range(img.shape[1] - W))
    H_slice = slice(H_offset, H_offset + H)
    W_slice = slice(W_offset, W_offset + W)
    img = img[H_slice, W_slice, :]
    gt = gt[H_slice, W_slice]
    
    return img, gt

class PairwiseImg(Dataset):
    """DAVIS 2016 dataset constructed using the PyTorch built-in functionalities"""

    def __init__(self, train=True,
                 inputRes=None,
                 db_root_dir='/DAVIS-2016',
                 img_root_dir = None,
                 transform=None,
                 meanval=(104.00699, 116.66877, 122.67892),
                 seq_name=None, sample_range=10):
        """Loads image to label pairs for tool pose estimation
        db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations"
        """
        self.train = train
        self.range = sample_range
        self.inputRes = inputRes
        self.img_root_dir = img_root_dir
        self.db_root_dir = db_root_dir
        self.transform = transform
        self.meanval = meanval
        self.seq_name = seq_name

        if self.train:
            fname = 'train_seqs'
        else:
            fname = 'val_seqs'

        if self.seq_name is None:  # all sequences take part in training
            with open(os.path.join(db_root_dir, fname + '.txt')) as f:
                seqs = f.readlines()
                video_list = []
                labels = []
                Index = {}
                image_list = []
                im_label = []
                for seq in seqs:                    
                    images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\n'))))
                    images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images))
                    start_num = len(video_list)
                    video_list.extend(images_path)
                    end_num = len(video_list)
                    Index[seq.strip('\n')]= np.array([start_num, end_num])
                    lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\n'))))
                    lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab))
                    labels.extend(lab_path)
                    
                with open('/home/ubuntu/xiankai/saliency_data.txt') as f:
                    seqs = f.readlines()
                #data_list = np.sort(os.listdir(db_root_dir))
                    for seq in seqs:  # every saliency dataset
                        seq = seq.strip('\n')
                        images = np.sort(os.listdir(os.path.join(img_root_dir, seq.strip()) + '/images/'))  # one dataset, e.g. DUT
            # Initialize the original DAVIS splits for training the parent network
                        images_path = list(map(lambda x: os.path.join((seq +'/images'), x), images))         
                        image_list.extend(images_path)
                        lab = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/saliencymaps'))
                        lab_path = list(map(lambda x: os.path.join((seq +'/saliencymaps'),x), lab))
                        im_label.extend(lab_path)
        else:  # for the training samples, video_list stores the image paths

            # Initialize the per sequence images for online training
            names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name))))
            video_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img))
            #name_label = np.sort(os.listdir(os.path.join(db_root_dir,  str(seq_name))))
            labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])]
            labels.extend([None]*(len(names_img)-1))  # pad the labels list with None entries
            if self.train:
                video_list = [video_list[0]]
                labels = [labels[0]]

        assert (len(labels) == len(video_list))

        self.video_list = video_list
        self.labels = labels
        self.image_list = image_list
        self.img_labels = im_label
        self.Index = Index
        #img_files = open('all_im.txt','w+')

    def __len__(self):
        print(len(self.video_list), len(self.image_list))
        return len(self.video_list)
    
    def __getitem__(self, idx):
        target, target_gt = self.make_video_gt_pair(idx)
        target_id = idx
        img_idx = np.random.randint(1,len(self.image_list)-1)
        seq_name1 = self.video_list[idx].split('/')[-2]  # get the video (sequence) name
        if self.train:
            my_index = self.Index[seq_name1]
            search_id = np.random.randint(my_index[0], my_index[1])#min(len(self.video_list)-1, target_id+np.random.randint(1,self.range+1))
            while search_id == target_id:  # re-draw so the search frame differs from the target
                search_id = np.random.randint(my_index[0], my_index[1])
            search, search_gt = self.make_video_gt_pair(search_id)
            img, img_gt = self.make_img_gt_pair(img_idx)
            sample = {'target': target, 'target_gt': target_gt, 'search': search, 'search_gt': search_gt, \
                      'img': img, 'img_gt': img_gt}
            #np.save('search1.npy',search)
            #np.save('search_gt.npy',search_gt)
            if self.seq_name is not None:
                fname = os.path.join(self.seq_name, "%05d" % idx)
                sample['fname'] = fname

            if self.transform is not None:
                sample = self.transform(sample)
       
        else:
            img, gt = self.make_video_gt_pair(idx)
            sample = {'image': img, 'gt': gt}
            if self.seq_name is not None:
                fname = os.path.join(self.seq_name, "%05d" % idx)
                sample['fname'] = fname
        
        
        
        return sample  # the final output of this class

    def make_video_gt_pair(self, idx):  # helper that serves __getitem__
        """
        Make the image-ground-truth pair
        """
        img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[idx]), cv2.IMREAD_COLOR)
        if self.labels[idx] is not None and self.train:
            label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE)
            #print(os.path.join(self.db_root_dir, self.labels[idx]))
        else:
            gt = np.zeros(img.shape[:-1], dtype=np.uint8)
            
        ## the image and its ground truth are loaded; data augmentation can start
        if self.train:  #scaling, cropping and flipping
             img, label = my_crop(img,label)
             scale = random.uniform(0.7, 1.3)
             flip_p = random.uniform(0, 1)
             img_temp = scale_im(img,scale)
             img_temp = flip(img_temp,flip_p)
             gt_temp = scale_gt(label,scale)
             gt_temp = flip(gt_temp,flip_p)
             
             img = img_temp
             label = gt_temp
             
        if self.inputRes is not None:
            img = imresize(img, self.inputRes)
            #print('ok1')
            #scipy.misc.imsave('label.png',label)
            #scipy.misc.imsave('img.png',img)
            if self.labels[idx] is not None and self.train:
                label = imresize(label, self.inputRes, interp='nearest')

        img = np.array(img, dtype=np.float32)
        #img = img[:, :, ::-1]
        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        
        if self.labels[idx] is not None and self.train:
                gt = np.array(label, dtype=np.int32)
                gt[gt!=0]=1
                #gt = gt/np.max([gt.max(), 1e-8])
        #np.save('gt.npy')
        return img, gt

    def get_img_size(self):
        img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[0]))
        
        return list(img.shape[:2])

    def make_img_gt_pair(self, idx):  # helper that serves __getitem__
        """
        Make the image-ground-truth pair
        """
        img = cv2.imread(os.path.join(self.img_root_dir, self.image_list[idx]),cv2.IMREAD_COLOR)
        #print(os.path.join(self.db_root_dir, self.img_list[idx]))
        if self.img_labels[idx] is not None and self.train:
            label = cv2.imread(os.path.join(self.img_root_dir, self.img_labels[idx]),cv2.IMREAD_GRAYSCALE)
            #print(os.path.join(self.db_root_dir, self.labels[idx]))
        else:
            gt = np.zeros(img.shape[:-1], dtype=np.uint8)
            
        if self.inputRes is not None:            
            img = imresize(img, self.inputRes)
            if self.img_labels[idx] is not None and self.train:
                label = imresize(label, self.inputRes, interp='nearest')

        img = np.array(img, dtype=np.float32)
        #img = img[:, :, ::-1]
        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        
        if self.img_labels[idx] is not None and self.train:
                gt = np.array(label, dtype=np.int32)
                gt[gt!=0]=1
                #gt = gt/np.max([gt.max(), 1e-8])
        #np.save('gt.npy')
        return img, gt
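
    # __getitem__ above pairs each target frame with a random "search" frame from
    # the same sequence, using the [start, end) range stored in self.Index. A
    # standalone sketch of that sampling (illustrative helper, not used by the class):
def sample_search_id(index_range, target_id, rng=random):
    start, end = index_range
    search_id = rng.randrange(start, end)
    # re-draw until the search frame differs from the target
    # (every sequence has more than one frame, so this terminates)
    while search_id == target_id and end - start > 1:
        search_id = rng.randrange(start, end)
    return search_id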
    
if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt

    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])

    #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',
                       # train=True, transform=transforms)
    #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
#
#    for i, data in enumerate(dataloader):
#        plt.figure()
#        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))
#        if i == 10:
#            break
#
#    plt.show(block=True)

================================================
FILE: dataloaders/PairwiseImg_video_test_try.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 12 11:39:54 2018

@author: carri
"""
# for testing case
from __future__ import division

import os
import numpy as np
import cv2
from scipy.misc import imresize
import scipy.misc 
import random
import torch
from dataloaders.helpers import *
from torch.utils.data import Dataset

def flip(I,flip_p):
    if flip_p>0.5:
        return np.fliplr(I)
    else:
        return I

def scale_im(img_temp, scale):
    # cv2.resize expects dsize as (width, height): shape[1] is width, shape[0] is height
    new_dims = (int(img_temp.shape[1]*scale), int(img_temp.shape[0]*scale))
    return cv2.resize(img_temp, new_dims).astype(float)

def scale_gt(img_temp, scale):
    new_dims = (int(img_temp.shape[1]*scale), int(img_temp.shape[0]*scale))
    return cv2.resize(img_temp, new_dims, interpolation=cv2.INTER_NEAREST).astype(float)

def my_crop(img,gt):
    H = int(0.9 * img.shape[0])
    W = int(0.9 * img.shape[1])
    H_offset = random.choice(range(img.shape[0] - H))
    W_offset = random.choice(range(img.shape[1] - W))
    H_slice = slice(H_offset, H_offset + H)
    W_slice = slice(W_offset, W_offset + W)
    img = img[H_slice, W_slice, :]
    gt = gt[H_slice, W_slice]
    
    return img, gt

class PairwiseImg(Dataset):
    """DAVIS 2016 dataset constructed using the PyTorch built-in functionalities"""

    def __init__(self, train=True,
                 inputRes=None,
                 db_root_dir='/DAVIS-2016',
                 transform=None,
                 meanval=(104.00699, 116.66877, 122.67892),
                 seq_name=None, sample_range=10):
        """Loads image to label pairs for tool pose estimation
        db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations"
        """
        self.train = train
        self.range = sample_range
        self.inputRes = inputRes
        self.db_root_dir = db_root_dir
        self.transform = transform
        self.meanval = meanval
        self.seq_name = seq_name

        if self.train:
            fname = 'train_seqs'
        else:
            fname = 'val_seqs'

        if self.seq_name is None:  # all sequences take part in training
            with open(os.path.join(db_root_dir, fname + '.txt')) as f:
                seqs = f.readlines()
                img_list = []
                labels = []
                Index = {}
                for seq in seqs:                    
                    images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\n'))))
                    images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images))
                    start_num = len(img_list)
                    img_list.extend(images_path)
                    end_num = len(img_list)
                    Index[seq.strip('\n')]= np.array([start_num, end_num])
                    lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\n'))))
                    lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab))
                    labels.extend(lab_path)
        else:  # for the training samples, img_list stores the image paths

            # Initialize the per sequence images for online training
            names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name))))
            img_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img))
            #name_label = np.sort(os.listdir(os.path.join(db_root_dir,  str(seq_name))))
            labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])]
            labels.extend([None]*(len(names_img)-1))  # pad the labels list with None entries
            if self.train:
                img_list = [img_list[0]]
                labels = [labels[0]]

        assert (len(labels) == len(img_list))

        self.img_list = img_list
        self.labels = labels
        self.Index = Index
        #img_files = open('all_im.txt','w+')

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        target, target_grt, sequence_name = self.make_img_gt_pair(idx)  # the frame to segment at test time
        target_id = idx
        seq_name1 = self.img_list[target_id].split('/')[-2]  # get the video (sequence) name

        #target_grts = torch.stack((torch.from_numpy(target_grt), torch.from_numpy(target_grt_1)))
        #print('video name', seq_name1 )
        sample = {'target': target, 'target_gt': target_grt, 'seq_name': sequence_name, 'search_0': None}
        if self.range>=1:
            my_index = self.Index[seq_name1]
            search_num = list(range(my_index[0], my_index[1]))
            search_ids = random.sample(search_num, self.range)#min(len(self.img_list)-1, target_id+np.random.randint(1,self.range+1))
            searchs=[]
            for i in range(0,self.range):

                search_id = search_ids[i]
                search, search_grt,sequence_name = self.make_img_gt_pair(search_id)
                searchs.append(torch.from_numpy(search))
                #search_grts = torch.stack((torch.from_numpy(search_grt), torch.from_numpy(search_grt_1)))
            # 'search_0' is still unset at this point, so stack all sampled search frames into it
            sample['search_0'] = torch.stack(searchs, dim=0)
            #np.save('search1.npy',search)
            #np.save('search_gt.npy',search_gt)
            if self.seq_name is not None:
                fname = os.path.join(self.seq_name, "%05d" % idx)
                sample['fname'] = fname
       
        else:
            img, gt = self.make_img_gt_pair(idx)
            sample = {'image': img, 'gt': gt}
            if self.seq_name is not None:
                fname = os.path.join(self.seq_name, "%05d" % idx)
                sample['fname'] = fname

        return sample  # the final output of this class

    def make_img_gt_pair(self, idx):  # helper that serves __getitem__
        """
        Make the image-ground-truth pair
        """
        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx]), cv2.IMREAD_COLOR)
        if self.labels[idx] is not None and self.train:
            label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE)
            #print(os.path.join(self.db_root_dir, self.labels[idx]))
        else:
            gt = np.zeros(img.shape[:-1], dtype=np.uint8)
            
        ## the image and its ground truth are loaded; data augmentation can start
        if self.train:  #scaling, cropping and flipping
             img, label = my_crop(img,label)
             scale = random.uniform(0.7, 1.3)
             flip_p = random.uniform(0, 1)
             img_temp = scale_im(img,scale)
             img_temp = flip(img_temp,flip_p)
             gt_temp = scale_gt(label,scale)
             gt_temp = flip(gt_temp,flip_p)
             
             img = img_temp
             label = gt_temp
             
        if self.inputRes is not None:
            img = imresize(img, self.inputRes)
            #print('ok1')
            #scipy.misc.imsave('label.png',label)
            #scipy.misc.imsave('img.png',img)
            if self.labels[idx] is not None and self.train:
                label = imresize(label, self.inputRes, interp='nearest')

        img = np.array(img, dtype=np.float32)
        #img = img[:, :, ::-1]
        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        
        if self.labels[idx] is not None and self.train:
                gt = np.array(label, dtype=np.int32)
                gt[gt!=0]=1
                #gt = gt/np.max([gt.max(), 1e-8])
        #np.save('gt.npy')
        sequence_name = self.img_list[idx].split('/')[2]
        return img, gt, sequence_name 

    def get_img_size(self):
        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))
        
        return list(img.shape[:2])


if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt

    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])

    #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',
                       # train=True, transform=transforms)
    #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
#
#    for i, data in enumerate(dataloader):
#        plt.figure()
#        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))
#        if i == 10:
#            break
#
#    plt.show(block=True)
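The loaders above all share the same preprocessing: BGR mean subtraction, an HWC-to-CHW transpose, and binarization of the annotation mask. A minimal standalone sketch of that pipeline (shapes and test values are illustrative, not from the dataset):

```python
import numpy as np

# Per-channel BGR mean used by the dataset classes in this repo.
MEANVAL = (104.00699, 116.66877, 122.67892)

def preprocess(img, label):
    """img: HxWx3 uint8 (BGR, as read by cv2), label: HxW uint8.
    Returns (3xHxW float32 image, HxW int32 binary mask)."""
    img = np.asarray(img, dtype=np.float32)
    img -= np.array(MEANVAL, dtype=np.float32)  # per-channel mean subtraction
    img = img.transpose((2, 0, 1))              # HWC -> CHW for PyTorch
    gt = np.asarray(label, dtype=np.int32)
    gt[gt != 0] = 1                             # binarize the annotation mask
    return img, gt

img, gt = preprocess(np.zeros((4, 5, 3), np.uint8),
                     np.array([[0, 128], [255, 0]], np.uint8))
```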


================================================
FILE: dataloaders/PairwiseImg_video_try.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 12 11:39:54 2018

@author: carri
"""

from __future__ import division

import os
import numpy as np
import cv2
from scipy.misc import imresize  # deprecated; removed in SciPy >= 1.3, requires an older scipy
import scipy.misc 
import random
import torch
from dataloaders.helpers import *
from torch.utils.data import Dataset

def flip(I,flip_p):
    if flip_p>0.5:
        return np.fliplr(I)
    else:
        return I

def scale_im(img_temp,scale):
    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )
    return cv2.resize(img_temp,new_dims).astype(float)

def scale_gt(img_temp,scale):
    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )
    return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float)

def my_crop(img,gt):
    H = int(0.9 * img.shape[0])
    W = int(0.9 * img.shape[1])
    H_offset = random.choice(range(img.shape[0] - H))
    W_offset = random.choice(range(img.shape[1] - W))
    H_slice = slice(H_offset, H_offset + H)
    W_slice = slice(W_offset, W_offset + W)
    img = img[H_slice, W_slice, :]
    gt = gt[H_slice, W_slice]
    
    return img, gt

class PairwiseImg(Dataset):
    """DAVIS 2016 dataset constructed using the PyTorch built-in functionalities"""

    def __init__(self, train=True,
                 inputRes=None,
                 db_root_dir='/DAVIS-2016',
                 img_root_dir = None,
                 transform=None,
                 meanval=(104.00699, 116.66877, 122.67892),
                 seq_name=None, sample_range=10):
        """Loads image to label pairs for tool pose estimation
        db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations"
        """
        self.train = train
        self.range = sample_range
        self.inputRes = inputRes
        self.img_root_dir = img_root_dir
        self.db_root_dir = db_root_dir
        self.transform = transform
        self.meanval = meanval
        self.seq_name = seq_name

        if self.train:
            fname = 'train_seqs'
        else:
            fname = 'val_seqs'

        if self.seq_name is None: # all sequences participate in training
            with open(os.path.join(db_root_dir, fname + '.txt')) as f:
                seqs = f.readlines()
                video_list = []
                labels = []
                Index = {}
                image_list = []
                im_label = []
                for seq in seqs:                    
                    images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\n'))))
                    images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images))
                    start_num = len(video_list)
                    video_list.extend(images_path)
                    end_num = len(video_list)
                    Index[seq.strip('\n')]= np.array([start_num, end_num])
                    lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\n'))))
                    lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab))
                    labels.extend(lab_path)
                    
                with open('/home/ubuntu/xiankai/saliency_data.txt') as f:
                    seqs = f.readlines()
                #data_list = np.sort(os.listdir(db_root_dir))
                    for seq in seqs: # iterate over every saliency dataset
                        seq = seq.strip('\n') 
                        images = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/images/')) # images of one dataset, e.g. DUT
            # Initialize the original DAVIS splits for training the parent network
                        images_path = list(map(lambda x: os.path.join((seq +'/images'), x), images))         
                        image_list.extend(images_path)
                        lab = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/saliencymaps'))
                        lab_path = list(map(lambda x: os.path.join((seq +'/saliencymaps'),x), lab))
                        im_label.extend(lab_path)
        else: # single-sequence case: video_list stores the image paths of one sequence

            # Initialize the per sequence images for online training
            names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name))))
            video_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img))
            #name_label = np.sort(os.listdir(os.path.join(db_root_dir,  str(seq_name))))
            labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])]
            labels.extend([None]*(len(names_img)-1)) # pad the labels list with None
            if self.train:
                video_list = [video_list[0]]
                labels = [labels[0]]

        assert (len(labels) == len(video_list))

        self.video_list = video_list
        self.labels = labels
        self.image_list = image_list
        self.img_labels = im_label
        self.Index = Index
        #img_files = open('all_im.txt','w+')

    def __len__(self):
        print(len(self.video_list), len(self.image_list))
        return len(self.video_list)
    
    def __getitem__(self, idx):
        target, target_grt = self.make_video_gt_pair(idx)
        target_id = idx
        img_idx = random.sample([my_i for my_i in range(0,len(self.image_list))],2)

        seq_name1 = self.video_list[idx].split('/')[-2] # sequence name of the target frame
        my_index = self.Index[seq_name1]
        video_idx = random.sample([my_i for my_i in range(my_index[0],my_index[1])],3)
        target_1, target_grt_1 = self.make_video_gt_pair(video_idx[0])
        #print('type:', type(target))

        #targets = torch.stack((torch.from_numpy(target),torch.from_numpy(target_1)))
        #target_grts = torch.stack((torch.from_numpy(target_grt),torch.from_numpy(target_grt_1)))
        #print('size:', torch.from_numpy(target_grt).size(), torch.from_numpy(target_grt_1).size())
        if self.train:
            #my_index = self.Index[seq_name1]
            search, search_grt = self.make_video_gt_pair(video_idx[1])
            search_1, search_grt_1 = self.make_video_gt_pair(video_idx[2])
            searchs = torch.stack((torch.from_numpy(search), torch.from_numpy(search_1)))
            search_grts = torch.stack((torch.from_numpy(search_grt), torch.from_numpy(search_grt_1)))
            img, img_grt = self.make_img_gt_pair(img_idx[0])
            #img_1, img_grt_1 = self.make_img_gt_pair(img_idx[1])
            #imgs = torch.stack((torch.from_numpy(img), torch.from_numpy(img_1)))
            #img_grts = torch.stack((torch.torch.from_numpy(img_grt), torch.from_numpy(img_grt_1)))
            sample = {'target': target, 'target_grt': target_grt, 'search': searchs, 'search_grt': search_grts, \
                      'img': img, 'img_grt': img_grt}
            #np.save('search1.npy',search)
            if self.seq_name is not None:
                fname = os.path.join(self.seq_name, "%05d" % idx)
                sample['fname'] = fname

            if self.transform is not None:
                sample = self.transform(sample)
       
        else:
            img, gt = self.make_video_gt_pair(idx)
            sample = {'image': img, 'gt': gt}
            if self.seq_name is not None:
                fname = os.path.join(self.seq_name, "%05d" % idx)
                sample['fname'] = fname
        
        
        
        return sample  # the final output of this class

    def make_video_gt_pair(self, idx): # helper that serves __getitem__
        """
        Make the image-ground-truth pair
        """
        img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[idx]), cv2.IMREAD_COLOR)
        if self.labels[idx] is not None and self.train:
            label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE)
            #print(os.path.join(self.db_root_dir, self.labels[idx]))
        else:
            gt = np.zeros(img.shape[:-1], dtype=np.uint8)
            
         ## image and its ground truth are loaded; data augmentation can now be applied
        if self.train:  #scaling, cropping and flipping
             img, label = my_crop(img,label)
             scale = random.uniform(0.7, 1.3)
             flip_p = random.uniform(0, 1)
             img_temp = scale_im(img,scale)
             img_temp = flip(img_temp,flip_p)
             gt_temp = scale_gt(label,scale)
             gt_temp = flip(gt_temp,flip_p)
             
             img = img_temp
             label = gt_temp
             
        if self.inputRes is not None:
            img = imresize(img, self.inputRes)
            #print('ok1')
            #scipy.misc.imsave('label.png',label)
            #scipy.misc.imsave('img.png',img)
            if self.labels[idx] is not None and self.train:
                label = imresize(label, self.inputRes, interp='nearest')

        img = np.array(img, dtype=np.float32)
        #img = img[:, :, ::-1]
        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        
        if self.labels[idx] is not None and self.train:
                gt = np.array(label, dtype=np.int32)
                gt[gt!=0]=1
                #gt = gt/np.max([gt.max(), 1e-8])
        #np.save('gt.npy')
        return img, gt

    def get_img_size(self):
        img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[0]))
        
        return list(img.shape[:2])

    def make_img_gt_pair(self, idx): # helper that serves __getitem__
        """
        Make the image-ground-truth pair
        """
        img = cv2.imread(os.path.join(self.img_root_dir, self.image_list[idx]),cv2.IMREAD_COLOR)
        #print(os.path.join(self.db_root_dir, self.img_list[idx]))
        if self.img_labels[idx] is not None and self.train:
            label = cv2.imread(os.path.join(self.img_root_dir, self.img_labels[idx]),cv2.IMREAD_GRAYSCALE)
            #print(os.path.join(self.db_root_dir, self.labels[idx]))
        else:
            gt = np.zeros(img.shape[:-1], dtype=np.uint8)
            
        if self.inputRes is not None:            
            img = imresize(img, self.inputRes)
            if self.img_labels[idx] is not None and self.train:
                label = imresize(label, self.inputRes, interp='nearest')

        img = np.array(img, dtype=np.float32)
        #img = img[:, :, ::-1]
        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        
        if self.img_labels[idx] is not None and self.train:
                gt = np.array(label, dtype=np.int32)
                gt[gt!=0]=1
                #gt = gt/np.max([gt.max(), 1e-8])
        #np.save('gt.npy')
        return img, gt
    
if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt

    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])

    #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',
                       # train=True, transform=transforms)
    #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
#
#    for i, data in enumerate(dataloader):
#        plt.figure()
#        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))
#        if i == 10:
#            break
#
#    plt.show(block=True)
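`PairwiseImg.__getitem__` above samples companions for a target frame: extra frames are drawn from the same video using the per-sequence `(start, end)` ranges stored in `self.Index`, plus random static saliency images. A toy sketch of that sampling scheme (the sequence names and counts below are made up, not real DAVIS figures):

```python
import random

# Toy per-sequence index ranges and a matching flat frame list,
# mimicking self.Index / self.video_list in PairwiseImg.
Index = {'bear': (0, 82), 'blackswan': (82, 132)}
video_list = ['bear'] * 82 + ['blackswan'] * 50

def sample_companions(idx, n_video=3, n_image=2, n_static_images=1000):
    """For target frame idx, draw n_video distinct frames from the SAME
    sequence and n_image indices into the static saliency image pool."""
    seq = video_list[idx]
    start, end = Index[seq]
    video_idx = random.sample(range(start, end), n_video)      # same-clip frames
    img_idx = random.sample(range(n_static_images), n_image)   # static images
    return video_idx, img_idx

v_idx, i_idx = sample_companions(10)
```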

================================================
FILE: dataloaders/r
================================================



================================================
FILE: deeplab/__init__.py
================================================



================================================
FILE: deeplab/e
================================================



================================================
FILE: deeplab/siamese_model_conf.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 16 10:01:14 2018

@author: carri
"""

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn import init
affine_par = True
# Unlike siamese_model_concat, this uses the standard DeepLab-v3 backbone and adds an asymmetric branch

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)  # change
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        padding = dilation
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,  # change
                               padding=padding, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ASPP(nn.Module):
    def __init__(self, dilation_series, padding_series, depth):
        super(ASPP, self).__init__()
        self.mean = nn.AdaptiveAvgPool2d((1,1))
        self.conv= nn.Conv2d(2048, depth, 1,1)
        self.bn_x = nn.BatchNorm2d(depth)
        self.conv2d_0 = nn.Conv2d(2048, depth, kernel_size=1, stride=1)
        self.bn_0 = nn.BatchNorm2d(depth)
        self.conv2d_1 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[0], dilation=dilation_series[0])
        self.bn_1 = nn.BatchNorm2d(depth)
        self.conv2d_2 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[1], dilation=dilation_series[1])
        self.bn_2 = nn.BatchNorm2d(depth)
        self.conv2d_3 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[2], dilation=dilation_series[2])
        self.bn_3 = nn.BatchNorm2d(depth)
        self.relu = nn.ReLU(inplace=True)
        self.bottleneck = nn.Conv2d( depth*5, 256, kernel_size=3, padding=1 )  #512 1x1Conv
        self.bn = nn.BatchNorm2d(256)
        self.prelu = nn.PReLU()
        #for m in self.conv2d_list:
        #    m.weight.data.normal_(0, 0.01)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            
    def _make_stage_(self, dilation1, padding1):
        Conv = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=padding1, dilation=dilation1, bias=True)#classes
        Bn = nn.BatchNorm2d(256)
        Relu = nn.ReLU(inplace=True)
        return nn.Sequential(Conv, Bn, Relu)
        

    def forward(self, x):
        #out = self.conv2d_list[0](x)
        #mulBranches = [conv2d_l(x) for conv2d_l in self.conv2d_list]
        size=x.shape[2:]
        image_features=self.mean(x)
        image_features=self.conv(image_features)
        image_features = self.bn_x(image_features)
        image_features = self.relu(image_features)
        image_features=F.upsample(image_features, size=size, mode='bilinear', align_corners=True)
        out_0 = self.conv2d_0(x)
        out_0 = self.bn_0(out_0) 
        out_0 = self.relu(out_0)
        out_1 = self.conv2d_1(x)
        out_1 = self.bn_1(out_1) 
        out_1 = self.relu(out_1)
        out_2 = self.conv2d_2(x)
        out_2 = self.bn_2(out_2) 
        out_2 = self.relu(out_2)
        out_3 = self.conv2d_3(x)
        out_3 = self.bn_3(out_3) 
        out_3 = self.relu(out_3)
        out = torch.cat([image_features, out_0, out_1, out_2, out_3], 1)
        out = self.bottleneck(out)
        out = self.bn(out)
        out = self.prelu(out)
        #for i in range(len(self.conv2d_list) - 1):
        #    out += self.conv2d_list[i + 1](x)
        
        return out
  


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        self.layer5 = self._make_pred_layer(ASPP, [ 6, 12, 18], [6, 12, 18], 512)
        self.main_classifier = nn.Conv2d(256, num_classes, kernel_size=1)
        self.softmax = nn.Sigmoid()#nn.Softmax()
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
        for i in downsample._modules['1'].parameters():
            i.requires_grad = False
        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def _make_pred_layer(self, block, dilation_series, padding_series, num_classes):
        return block(dilation_series, padding_series, num_classes)

    def forward(self, x):
        input_size = x.size()[2:]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        fea = self.layer5(x)
        x = self.main_classifier(fea)
        #print("before upsample, tensor size:", x.size())
        x = F.upsample(x, input_size, mode='bilinear')  #upsample to the size of input image, scale=8
        #print("after upsample, tensor size:", x.size())
        x = self.softmax(x)
        return fea, x

class CoattentionModel(nn.Module):
    def  __init__(self, block, layers, num_classes, all_channel=256, all_dim=60*60):	#473./8=60	
        super(CoattentionModel, self).__init__()
        self.encoder = ResNet(block, layers, num_classes)
        self.linear_e = nn.Linear(all_channel, all_channel,bias = False)
        self.channel = all_channel
        self.dim = all_dim
        self.gate = nn.Conv2d(all_channel, 1, kernel_size  = 1, bias = False)
        self.gate_s = nn.Sigmoid()
        self.conv1 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False)
        self.conv2 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False)
        self.bn1 = nn.BatchNorm2d(all_channel)
        self.bn2 = nn.BatchNorm2d(all_channel)
        self.prelu = nn.ReLU(inplace=True)
        self.main_classifier1 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True)
        self.main_classifier2 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True)
        self.softmax = nn.Sigmoid()
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, 0.01)
                #init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                #init.xavier_normal(m.weight.data)
                #m.bias.data.fill_(0)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()    
    
		
    def forward(self, input1, input2): # note: input2 may contain multiple frames
        
        #input1_att, input2_att = self.coattention(input1, input2) 
        input_size = input1.size()[2:]
        exemplar, temp = self.encoder(input1)
        query, temp = self.encoder(input2)		 
        fea_size = query.size()[2:]	 
        all_dim = fea_size[0]*fea_size[1]
        exemplar_flat = exemplar.view(-1, query.size()[1], all_dim) #N,C,H*W
        query_flat = query.view(-1, query.size()[1], all_dim)
        exemplar_t = torch.transpose(exemplar_flat,1,2).contiguous()  #batch size x dim x num
        exemplar_corr = self.linear_e(exemplar_t) # 
        A = torch.bmm(exemplar_corr, query_flat)
        A1 = F.softmax(A.clone(), dim = 1) #
        B = F.softmax(torch.transpose(A,1,2),dim=1)
        query_att = torch.bmm(exemplar_flat, A1).contiguous() # note: consider whether to use interaction and a residual structure here
        exemplar_att = torch.bmm(query_flat, B).contiguous()
        
        input1_att = exemplar_att.view(-1, query.size()[1], fea_size[0], fea_size[1])  
        input2_att = query_att.view(-1, query.size()[1], fea_size[0], fea_size[1])
        input1_mask = self.gate(input1_att)
        input2_mask = self.gate(input2_att)
        input1_mask = self.gate_s(input1_mask)
        input2_mask = self.gate_s(input2_mask)
        input1_att = input1_att * input1_mask
        input2_att = input2_att * input2_mask
        input1_att = torch.cat([input1_att, exemplar],1) 
        input2_att = torch.cat([input2_att, query],1)
        input1_att  = self.conv1(input1_att )
        input2_att  = self.conv2(input2_att ) 
        input1_att  = self.bn1(input1_att )
        input2_att  = self.bn2(input2_att )
        input1_att  = self.prelu(input1_att )
        input2_att  = self.prelu(input2_att )
        x1 = self.main_classifier1(input1_att)
        x2 = self.main_classifier2(input2_att)   
        x1 = F.upsample(x1, input_size, mode='bilinear')  #upsample to the size of input image, scale=8
        x2 = F.upsample(x2, input_size, mode='bilinear')  #upsample to the size of input image, scale=8
        #print("after upsample, tensor size:", x.size())
        x1 = self.softmax(x1)
        x2 = self.softmax(x2)
        
#        x1 = self.softmax(x1)
#        x2 = self.softmax(x2)
        return x1, x2, temp  #shape: NxCx	
    

def Res_Deeplab(num_classes=2):
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes-1)
    return model

def CoattentionNet(num_classes=2):
    model = CoattentionModel(Bottleneck,[3, 4, 23, 3], num_classes-1)
	
    return model
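The core of `CoattentionModel.forward` above is the affinity computation `A = X_e^T W X_q` followed by row- and column-wise softmax attention. A minimal sketch on tiny tensors (C=4 channels and 3x3 feature maps are assumptions for illustration; the real model uses 256 channels on 60x60 maps):

```python
import torch
import torch.nn.functional as F

C, H, W = 4, 3, 3                              # toy sizes, not the model's
exemplar = torch.randn(1, C, H, W)             # features of frame 1
query = torch.randn(1, C, H, W)                # features of frame 2
linear_e = torch.nn.Linear(C, C, bias=False)   # learned affinity weight W

ex_flat = exemplar.view(1, C, H * W)                     # N x C x HW
q_flat = query.view(1, C, H * W)
ex_t = ex_flat.transpose(1, 2).contiguous()              # N x HW x C
A = torch.bmm(linear_e(ex_t), q_flat)                    # affinity: N x HW x HW
query_att = torch.bmm(ex_flat, F.softmax(A, dim=1))      # exemplar attends to query
exemplar_att = torch.bmm(q_flat, F.softmax(A.transpose(1, 2), dim=1))
```

Each attended map is then reshaped back to `N x C x H x W`, gated, concatenated with the original feature, and classified, as in the forward pass above.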


================================================
FILE: deeplab/siamese_model_conf_try_single.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 16 10:01:14 2018

@author: carri
"""

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn import init
affine_par = True
import numpy as np
# Unlike siamese_model_concat, this uses the standard DeepLab-v3 backbone and adds an asymmetric branch

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)  # change
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        padding = dilation
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,  # change
                               padding=padding, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ASPP(nn.Module):
    def __init__(self, dilation_series, padding_series, depth):
        super(ASPP, self).__init__()
        self.mean = nn.AdaptiveAvgPool2d((1,1))
        self.conv= nn.Conv2d(2048, depth, 1,1)
        self.bn_x = nn.BatchNorm2d(depth)
        self.conv2d_0 = nn.Conv2d(2048, depth, kernel_size=1, stride=1)
        self.bn_0 = nn.BatchNorm2d(depth)
        self.conv2d_1 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[0], dilation=dilation_series[0])
        self.bn_1 = nn.BatchNorm2d(depth)
        self.conv2d_2 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[1], dilation=dilation_series[1])
        self.bn_2 = nn.BatchNorm2d(depth)
        self.conv2d_3 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[2], dilation=dilation_series[2])
        self.bn_3 = nn.BatchNorm2d(depth)
        self.relu = nn.ReLU(inplace=True)
        self.bottleneck = nn.Conv2d( depth*5, 256, kernel_size=3, padding=1 )  #512 1x1Conv
        self.bn = nn.BatchNorm2d(256)
        self.prelu = nn.PReLU()
        #for m in self.conv2d_list:
        #    m.weight.data.normal_(0, 0.01)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            
    def _make_stage_(self, dilation1, padding1):
        Conv = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=padding1, dilation=dilation1, bias=True)#classes
        Bn = nn.BatchNorm2d(256)
        Relu = nn.ReLU(inplace=True)
        return nn.Sequential(Conv, Bn, Relu)
        

    def forward(self, x):
        #out = self.conv2d_list[0](x)
        #mulBranches = [conv2d_l(x) for conv2d_l in self.conv2d_list]
        size=x.shape[2:]
        image_features=self.mean(x)
        image_features=self.conv(image_features)
        image_features = self.bn_x(image_features)
        image_features = self.relu(image_features)
        image_features=F.upsample(image_features, size=size, mode='bilinear', align_corners=True)
        out_0 = self.conv2d_0(x)
        out_0 = self.bn_0(out_0) 
        out_0 = self.relu(out_0)
        out_1 = self.conv2d_1(x)
        out_1 = self.bn_1(out_1) 
        out_1 = self.relu(out_1)
        out_2 = self.conv2d_2(x)
        out_2 = self.bn_2(out_2) 
        out_2 = self.relu(out_2)
        out_3 = self.conv2d_3(x)
        out_3 = self.bn_3(out_3) 
        out_3 = self.relu(out_3)
        out = torch.cat([image_features, out_0, out_1, out_2, out_3], 1)
        out = self.bottleneck(out)
        out = self.bn(out)
        out = self.prelu(out)
        #for i in range(len(self.conv2d_list) - 1):
        #    out += self.conv2d_list[i + 1](x)
        
        return out
  


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        self.layer5 = self._make_pred_layer(ASPP, [ 6, 12, 18], [6, 12, 18], 512)
        self.main_classifier = nn.Conv2d(256, num_classes, kernel_size=1)
        self.softmax = nn.Sigmoid()#nn.Softmax()
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
            # freeze the downsample BatchNorm, matching the other frozen BN layers;
            # keep this inside the if so a None downsample is never dereferenced
            for i in downsample._modules['1'].parameters():
                i.requires_grad = False
        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def _make_pred_layer(self, block, dilation_series, padding_series, num_classes):
        return block(dilation_series, padding_series, num_classes)

    def forward(self, x):
        input_size = x.size()[2:]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        fea = self.layer5(x)
        x = self.main_classifier(fea)
        #print("before upsample, tensor size:", x.size())
        x = F.upsample(x, input_size, mode='bilinear')  #upsample to the size of input image, scale=8
        #print("after upsample, tensor size:", x.size())
        x = self.softmax(x)
        return fea, x
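With `layer3`/`layer4` dilated at stride 1, the backbone downsamples by 8 overall, which is why a 473x473 input yields the 60x60 feature map assumed by `CoattentionModel`'s `all_dim=60*60`. A sketch of the size arithmetic (assuming the standard Bottleneck 1x1 stride-2 downsample):

```python
import math

def deeplab_feat_size(hw):
    """Spatial size of the layer5 feature map for an hw x hw input."""
    hw = math.floor((hw + 2 * 3 - 7) / 2) + 1   # conv1: 7x7, stride 2, pad 3
    hw = math.ceil((hw + 2 * 1 - 3) / 2) + 1    # maxpool: 3x3, stride 2, pad 1, ceil_mode
    hw = math.floor((hw - 1) / 2) + 1           # layer2 downsample: 1x1, stride 2
    return hw                                   # layer3/layer4 keep stride 1 (dilated)

print(deeplab_feat_size(473))  # -> 60
```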

class CoattentionModel(nn.Module):
    def __init__(self, block, layers, num_classes, all_channel=256, all_dim=60*60):  # 473./8 = 60
        super(CoattentionModel, self).__init__()
        self.nframes = 2
        self.encoder = ResNet(block, layers, num_classes)
        self.linear_e = nn.Linear(all_channel, all_channel,bias = False)
        self.channel = all_channel
        self.dim = all_dim
        self.gate = nn.Conv2d(all_channel, 1, kernel_size  = 1, bias = False)
        self.gate_s = nn.Sigmoid()
        self.conv1 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False)
        self.conv2 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False)
        self.bn1 = nn.BatchNorm2d(all_channel, affine=affine_par)
        self.bn2 = nn.BatchNorm2d(all_channel, affine=affine_par)
        self.prelu = nn.ReLU(inplace=True)
        self.main_classifier1 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True)
        self.main_classifier2 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True)
        self.softmax = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, 0.01)
                #init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                #init.xavier_normal(m.weight.data)
                #m.bias.data.fill_(0)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()    
    
		
    def forward(self, input1, input2):  # note: input2 may contain multiple frames
        
        #input1_att, input2_att = self.coattention(input1, input2)
        exemplar, temp = self.encoder(input1)

        #print('feature size:', input1.size())
        if len(input2.size() )>4:
            B, N, C, H, W = input2.size()  # 2,2,3,473,473
            input_size = [H, W]
            video_frames2 = [elem.view(B, C, H, W) for elem in input2.split(1, dim=1)]
            # the length of exemplars is equal to the nframes
            querys= [self.encoder(video_frames2[frame]) for frame in range(0,self.nframes)]
            #query = torch.cat([querys[][0]],dim=1)
            #query1 = torch.cat([querys[1]], dim=1)
            query = torch.cat(([querys[frame][0] for frame in range(0,self.nframes )]), dim=2)
            #print('query size:', query.size()) 2*512*49*49
            predict_mask = torch.cat(([querys[frame][1] for frame in range(0,self.nframes )]), dim=1)
            #print('feature size:', exemplar.size())
            fea_size = exemplar.size()[2:]
            exemplar_flat = exemplar.view(-1, self.channel, fea_size[0]*fea_size[1]) #N,C,H*W
            exemplar_t = torch.transpose(exemplar_flat, 1, 2).contiguous()  # batch size x dim x num
            exemplar_corr = self.linear_e(exemplar_t)
            #coattention_fea = 0
            query_flat = query.view(-1, self.channel, self.nframes*fea_size[0]*fea_size[1])
            A = torch.bmm(exemplar_corr, query_flat)
            A = F.softmax(A, dim = 1) #
            B = F.softmax(torch.transpose(A,1,2),dim=1)
            query_att = torch.bmm(exemplar_flat, A).contiguous()  # note: consider whether to add interaction and a residual structure here
            exemplar_att = torch.bmm(query_flat, B).contiguous()

            input1_att = exemplar_att.view(-1, self.channel, fea_size[0], fea_size[1])
            input2_att = query_att.view(-1, self.channel, self.nframes*fea_size[0], fea_size[1])
            input1_mask = self.gate(input1_att)
            #input2_mask = self.gate(input2_att)
            input1_mask = self.gate_s(input1_mask)
            #input2_mask = self.gate_s(input2_mask)
            input1_att_org = input1_att * input1_mask
            #coattention_fea = coattention_fea + input1_att_org

            #print('h_v size, h_v_org size:', torch.max(input1_att), torch.max(exemplar), torch.min(input1_att), torch.max(exemplar))
            input1_att = torch.cat([input1_att_org, exemplar],1)
            input1_att  = self.conv1(input1_att )
            input1_att  = self.bn1(input1_att )
            input1_att  = self.prelu(input1_att )
            x1 = self.main_classifier1(input1_att)
            x1 = F.upsample(x1, input_size, mode='bilinear')  #upsample to the size of input image, scale=8
            #print("after upsample, tensor size:", x.size())
            x1  = self.softmax(x1)
        else:
            x1 = exemplar

        return x1, temp  #shape: NxCx


def CoattentionNet(num_classes=2, nframes=2):
    model = CoattentionModel(Bottleneck, [3, 4, 23, 3], num_classes-1)

    return model
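The co-attention core above computes a bilinear affinity between the flattened exemplar and query feature maps (`linear_e` followed by the two `bmm` calls) and routes features in both directions. A minimal NumPy sketch of the same math, with hypothetical small dimensions in place of the real 256x(60*60) maps:

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def coattention(exemplar, query, W):
    # exemplar: (C, N1) and query: (C, N2) flattened feature maps; W: (C, C) plays linear_e
    A = softmax(exemplar.T @ W @ query, axis=0)   # affinity, normalized over exemplar positions
    B = softmax(A.T, axis=0)                      # normalized over query positions
    query_att = exemplar @ A                      # (C, N2): exemplar features at query positions
    exemplar_att = query @ B                      # (C, N1): query features at exemplar positions
    return exemplar_att, query_att
```

Because each column of `A` and `B` is a softmax, every attended column is a convex combination of the other map's columns.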

================================================
FILE: deeplab/utils.py
================================================
import torch
#from tensorboard_logger import log_value
from torch.autograd import Variable


def loss_calc(pred, label, ignore_label):
    """
    This function returns cross entropy loss for semantic segmentation
    """
    # out shape batch_size x channels x h x w -> batch_size x channels x h x w
    # label shape h x w x 1 x batch_size  -> batch_size x 1 x h x w
    label = Variable(label.long()).cuda()
    criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label).cuda()

    return criterion(pred, label)


def lr_poly(base_lr, iter, max_iter, power):
    return base_lr * ((1 - float(iter) / max_iter) ** power)


def get_1x_lr_params(model):
    """
    This generator returns all the parameters of the net except for
    the last classification layer. Note that for each batchnorm layer,
    requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
    any batchnorm parameter
    """
    b = [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3, model.layer4]

    for i in range(len(b)):
        for j in b[i].modules():
            jj = 0
            for k in j.parameters():
                jj += 1
                if k.requires_grad:
                    yield k


def get_10x_lr_params(model):
    """
    This generator returns all the parameters for the last layer of the net,
    which does the classification of pixel into classes
    """
    b = [model.layer5.parameters(), model.main_classifier.parameters()]

    for j in range(len(b)):
        for i in b[j]:
            yield i


def adjust_learning_rate(optimizer, i_iter, learning_rate, num_steps, power):
    """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs"""
    lr = lr_poly(learning_rate, i_iter, num_steps, power)
    #log_value('learning', lr, i_iter)
    optimizer.param_groups[0]['lr'] = lr
    optimizer.param_groups[1]['lr'] = lr * 10


================================================
FILE: densecrf_apply_cvpr2019.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar  1 20:37:37 2019

@author: xiankai
"""

import pydensecrf.densecrf as dcrf
import numpy as np
import sys
import os


from skimage.io import imread, imsave
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian, unary_from_softmax
from os import listdir, makedirs
from os.path import isfile, join
from multiprocessing import Process

            
def worker(scale, g_dim, g_factor,s_dim,C_dim,c_factor):
    davis_path = '/home/xiankai/work/DAVIS-2016/JPEGImages/480p'#'/home/ying/tracking/pdb_results/FBMS-results'
    origin_path = '/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/COS-78.2'#'/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/ECCV'#'/media/xiankai/Data/segmentation/match-Weaksup_VideoSeg/result/test/davis_iteration_conf_sal_match_scale/COS/'
    out_folder = '/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/cvpr2019_crfs'#'/media/xiankai/Data/ECCV-crf'#'/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/davis_ICCV_new/'
    if not os.path.exists(out_folder):
        os.makedirs(out_folder)
    origin_file = listdir(origin_path)
    origin_file.sort()
    for i in range(0, len(origin_file)):
        d = origin_file[i]
        vidDir = join(davis_path, d)
        out_folder1 = join(out_folder,'f'+str(scale)+str(g_dim)+str(g_factor)+'_'+'s'+str(s_dim)+'_'+'c'+str(C_dim)+str(c_factor))
        resDir = join(out_folder1, d)
        if not os.path.exists(resDir):
                os.makedirs(resDir)
        rgb_file = listdir(vidDir)
        rgb_file.sort()
        for ii in range(0,len(rgb_file)):  
            f = rgb_file[ii]
            img = imread(join(vidDir, f))
            segDir = join(origin_path, d)
            frameName = str.split(f, '.')[0]
            anno_rgb = imread(segDir + '/' + frameName + '.png').astype(np.uint32)
            min_val = np.min(anno_rgb.ravel())
            max_val = np.max(anno_rgb.ravel())
            out = (anno_rgb.astype('float') - min_val) / (max_val - min_val)
            labels = np.zeros((2, img.shape[0], img.shape[1]))
            labels[1, :, :] = out
            labels[0, :, :] = 1 - out
    
            colors = [0, 255]
            colorize = np.empty((len(colors), 1), np.uint8)
            colorize[:,0] = colors
            n_labels = 2
    
            crf = dcrf.DenseCRF(img.shape[1] * img.shape[0], n_labels)
    
            U = unary_from_softmax(labels,scale)
            crf.setUnaryEnergy(U)
    
            feats = create_pairwise_gaussian(sdims=(g_dim, g_dim), shape=img.shape[:2])
    
            crf.addPairwiseEnergy(feats, compat=g_factor,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)
    
            feats = create_pairwise_bilateral(sdims=(s_dim,s_dim), schan=(C_dim, C_dim, C_dim),# 30,5
                                          img=img, chdim=2)
    
            crf.addPairwiseEnergy(feats, compat=c_factor,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)
    
            #Q = crf.inference(5)
            Q, tmp1, tmp2 = crf.startInference()
            for it in range(5):  # renamed from i to avoid shadowing the sequence index
                #print("KL-divergence at {}: {}".format(it, crf.klDivergence(Q)))
                crf.stepInference(Q, tmp1, tmp2)
    
            MAP = np.argmax(Q, axis=0)
            MAP = colorize[MAP]
            
            imsave(resDir + '/' + frameName + '.png', MAP.reshape(anno_rgb.shape))
            print ("Saving: " + resDir + '/' + frameName + '.png')
scales = [1]#[0.5,1]#[0.1,0.3,0.5,0.6]#[0.5, 1.0]
g_dims = [1]#[1,3]#[1,3]
g_factors =[5]#[3,5,10] #[ 3, 5,10]
s_dims = [10,15,20] #[5,10,20]#[11, 12, 13]#[9,10,11] #10
Cs = [7]#[5]#[8]# [ 7,8,9,10] #8
b_factors = [8,9,10]
for scale in scales: 
    for g_dim in g_dims:
        for ii in range(0,len(g_factors)):
            g_factor = g_factors[ii]
            for jj in range(0,len(s_dims)):
                s_dim = s_dims[jj]
                for cs in Cs:
                    p1 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, b_factors[0]))
                    p2 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, b_factors[1]))
                    p3 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, b_factors[2]))
                    #p4 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, 4))
                    #p5 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, 1))
                    #p6 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, 1))
                    
                    p1.start()
                    p2.start()
                    p3.start()
                    # wait for this batch of workers before launching the next,
                    # so the grid search does not spawn unbounded processes
                    p1.join()
                    p2.join()
                    p3.join()
            
            


================================================
FILE: pretrained/deep_labv3/readme.md
================================================



================================================
FILE: test_coattention_conf.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 17 17:53:20 2018

@author: carri
"""

import argparse
import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import pickle
import cv2
from torch.autograd import Variable
import torch.optim as optim
import scipy.misc
import torch.backends.cudnn as cudnn
import sys
import os
import os.path as osp
from dataloaders import PairwiseImg_test as db
#from dataloaders import StaticImg as db  # uses the VOC dataset format
import matplotlib.pyplot as plt
import random
import timeit
from PIL import Image
from collections import OrderedDict
import matplotlib.pyplot as plt
import torch.nn as nn
#from utils.colorize_mask import cityscapes_colorize_mask, VOCColorize
#import pydensecrf.densecrf as dcrf
#from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian
from deeplab.siamese_model_conf import CoattentionNet
from torchvision.utils import save_image

def get_arguments():
    """Parse all the arguments provided from the CLI.
    
    Returns:
      The parsed arguments (argparse.Namespace).
    """
    parser = argparse.ArgumentParser(description="PSPnet")
    parser.add_argument("--dataset", type=str, default='cityscapes',
                        help="voc12, cityscapes, or pascal-context")

    # GPU configuration
    parser.add_argument("--cuda", default=True, help="Run on CPU or GPU")
    parser.add_argument("--gpus", type=str, default="0",
                        help="choose gpu device.")
    parser.add_argument("--seq_name", default = 'bmx-bumps')
    parser.add_argument("--use_crf", default = 'True')
    parser.add_argument("--sample_range", default =5)
    
    return parser.parse_args()

def configure_dataset_model(args):
    if args.dataset == 'voc12':
        args.data_dir ='/home/wty/AllDataSet/VOC2012'  #Path to the directory containing the PASCAL VOC dataset
        args.data_list = './dataset/list/VOC2012/test.txt'  #Path to the file listing the images in the dataset
        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 
        #RBG mean, first subtract mean and then change to BGR
        args.ignore_label = 255   #The index of the label to ignore during the training
        args.num_classes = 21  #Number of classes to predict (including background)
        args.restore_from = './snapshots/voc12/psp_voc12_14.pth'  #Where restore model parameters from
        args.save_segimage = True
        args.seg_save_dir = "./result/test/VOC2012"
        args.corp_size =(505, 505)
        
    elif args.dataset == 'davis': 
        args.batch_size = 1   # 1 card: 5, 2 cards: 10; images per step (16 in the paper)
        args.maxEpoches = 15  # 1 card: 15, 2 cards: 15 epochs (~30k iterations); max iterations = maxEpoches*len(train_aug)/batch_size_per_gpu
        args.data_dir = 'your_path/DAVIS-2016'   # 37572 image pairs
        args.data_list = 'your_path/DAVIS-2016/test_seqs.txt'  # Path to the file listing the images in the dataset
        args.ignore_label = 255     #The index of the label to ignore during the training
        args.input_size = '473,473' #Comma-separated string with height and width of images
        args.num_classes = 2      #Number of classes to predict (including background)
        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)       # saving model file and log record during the process of training
        args.restore_from = './your_path.pth' #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #
        args.snapshot_dir = './snapshots/davis_iteration/'          #Where to save snapshots of the model
        args.save_segimage = True
        args.seg_save_dir = "./result/test/davis_iteration_conf"
        args.vis_save_dir = "./result/test/davis_vis"
        args.corp_size =(473, 473)
        
    else:
        print("dataset error")

def convert_state_dict(state_dict):
    """Converts a state dict saved from a dataParallel module to normal 
       module state_dict inplace
       :param state_dict is the loaded DataParallel model_state
       You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it 
       without DataParallel. You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can 
       load the weights file, create a new ordered dict without the module prefix, and load it back 
    """
    state_dict_new = OrderedDict()
    #print(type(state_dict))
    for k, v in state_dict.items():
        #print(k)
        name = k[7:]  # strip the 'module.' prefix added by nn.DataParallel
        state_dict_new[name] = v
        if name == 'linear_e.weight':
            np.save('weight_matrix.npy',v.cpu().numpy())
    return state_dict_new
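`convert_state_dict` slices `k[7:]` unconditionally, which would mangle keys that were saved without DataParallel. A slightly safer variant (`strip_module_prefix` is a hypothetical helper name) that only strips the prefix when it is present:

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    # remove the 'module.' prefix that nn.DataParallel prepends to parameter names
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())
```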

def sigmoid(inX): 
    return 1.0/(1+np.exp(-inX))  # sigmoid: 1/(1+e^-x)

def main():
    args = get_arguments()
    print("=====> Configure dataset and model")
    configure_dataset_model(args)
    print(args)
    model = CoattentionNet(num_classes=args.num_classes)
    
    saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage)
    #print(saved_state_dict.keys())
    #model.load_state_dict({k.replace('pspmodule.',''):v for k,v in torch.load(args.restore_from)['state_dict'].items()})
    model.load_state_dict( convert_state_dict(saved_state_dict["model"]) ) #convert_state_dict(saved_state_dict["model"])

    model.eval()
    model.cuda()
    if args.dataset == 'voc12':
        testloader = data.DataLoader(VOCDataTestSet(args.data_dir, args.data_list, crop_size=(505, 505),mean= args.img_mean), 
                                    batch_size=1, shuffle=False, pin_memory=True)
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
        voc_colorize = VOCColorize()
        
    elif args.dataset == 'davis':  #for davis 2016
        db_test = db.PairwiseImg(train=False, inputRes=(473,473), db_root_dir=args.data_dir,  transform=None, seq_name = None, sample_range = args.sample_range) #db_root_dir() --> '/path/to/DAVIS-2016' train path
        testloader = data.DataLoader(db_test, batch_size= 1, shuffle=False, num_workers=0)
        #voc_colorize = VOCColorize()
    else:
        print("dataset error")

    data_list = []

    if args.save_segimage:
        if not os.path.exists(args.seg_save_dir):
            os.makedirs(args.seg_save_dir)
        if not os.path.exists(args.vis_save_dir):
            os.makedirs(args.vis_save_dir)
    print("======> test set size:", len(testloader))
    my_index = 0
    old_temp=''
    for index, batch in enumerate(testloader):
        print('%d processed' % (index))
        target = batch['target']
        #search = batch['search']
        temp = batch['seq_name']
        args.seq_name=temp[0]
        print(args.seq_name)
        if old_temp==args.seq_name:
            my_index = my_index+1
        else:
            my_index = 0
        output_sum = 0   
        for i in range(0,args.sample_range):  
            search = batch['search'+'_'+str(i)]
            search_im = search
            #print(search_im.size())
            output = model(Variable(target, volatile=True).cuda(),Variable(search_im, volatile=True).cuda())
            #print(output[0])  # output has two elements
            output_sum = output_sum + output[0].data[0,0].cpu().numpy()  # result of the segmentation branch
            #np.save('infer'+str(i)+'.npy',output1)
            #output2 = output[1].data[0, 0].cpu().numpy() #interp'
        
        output1 = output_sum/args.sample_range
     
        first_image = np.array(Image.open(args.data_dir+'/JPEGImages/480p/blackswan/00000.jpg'))
        original_shape = first_image.shape  # all DAVIS-2016 480p sequences share this resolution
        output1 = cv2.resize(output1, (original_shape[1],original_shape[0]))

        mask = (output1*255).astype(np.uint8)
        #print(mask.shape[0])
        mask = Image.fromarray(mask)
        

        if args.dataset == 'voc12':
            print(output.shape)
            print(size)
            output = output[:,:size[0],:size[1]]
            output = output.transpose(1,2,0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            if args.save_segimage:
                seg_filename = os.path.join(args.seg_save_dir, '{}.png'.format(name[0]))
                color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')
                color_file.save(seg_filename)
                
        elif args.dataset == 'davis':
            
            save_dir_res = os.path.join(args.seg_save_dir, 'Results', args.seq_name)
            old_temp=args.seq_name
            if not os.path.exists(save_dir_res):
                os.makedirs(save_dir_res)
            if args.save_segimage:   
                my_index1 = str(my_index).zfill(5)
                seg_filename = os.path.join(save_dir_res, '{}.png'.format(my_index1))
                #color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')
                mask.save(seg_filename)
                #np.concatenate((torch.zeros(1, 473, 473), mask, torch.zeros(1, 512, 512)),axis = 0)
                #save_image(output1 * 0.8 + target.data, args.vis_save_dir, normalize=True)
        else:
            print("dataset error")
    

if __name__ == '__main__':
    main()
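The test loop above averages the segmentation branch over `sample_range` reference frames before scaling to an 8-bit PNG. That reduction in isolation, as a NumPy sketch:

```python
import numpy as np

def average_masks(prob_maps):
    # prob_maps: list of (H, W) float arrays in [0, 1], one per sampled reference frame
    avg = sum(prob_maps) / len(prob_maps)
    return (avg * 255).astype(np.uint8)
```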


================================================
FILE: test_iteration_conf_group.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 17 17:53:20 2018

@author: carri
"""

import argparse
import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import pickle
import cv2
from torch.autograd import Variable
import torch.optim as optim
import scipy.misc
import torch.backends.cudnn as cudnn
import sys
import os
import os.path as osp
from dataloaders import PairwiseImg_video_test_try as db
#from dataloaders import StaticImg as db  # uses the VOC dataset format
import matplotlib.pyplot as plt
import random
import timeit
from PIL import Image
from collections import OrderedDict
import matplotlib.pyplot as plt
import torch.nn as nn
from utils.colorize_mask import cityscapes_colorize_mask, VOCColorize
#import pydensecrf.densecrf as dcrf
#from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian
from deeplab.siamese_model_conf_try import CoattentionNet
from torchvision.utils import save_image

def get_arguments():
    """Parse all the arguments provided from the CLI.
    
    Returns:
      The parsed arguments (argparse.Namespace).
    """
    parser = argparse.ArgumentParser(description="PSPnet")
    parser.add_argument("--dataset", type=str, default='cityscapes',
                        help="voc12, cityscapes, or pascal-context")

    # GPU configuration
    parser.add_argument("--cuda", default=True, help="Run on CPU or GPU")
    parser.add_argument("--gpus", type=str, default="0",
                        help="choose gpu device.")
    parser.add_argument("--seq_name", default = 'bmx-bumps')
    parser.add_argument("--use_crf", default = 'True')
    parser.add_argument("--sample_range", default =3)
    
    return parser.parse_args()

def configure_dataset_model(args):

    args.batch_size = 1   # 1 card: 5, 2 cards: 10; images per step (16 in the paper)
    args.maxEpoches = 15  # 1 card: 15, 2 cards: 15 epochs (~30k iterations); max iterations = maxEpoches*len(train_aug)/batch_size_per_gpu
    args.data_dir = '/home/xiankai/work/DAVIS-2016'   # 37572 image pairs
    args.data_list = '/home/xiankai/work/DAVIS-2016/test_seqs.txt'  # Path to the file listing the images in the dataset
    args.ignore_label = 255     #The index of the label to ignore during the training
    args.input_size = '473,473' #Comma-separated string with height and width of images
    args.num_classes = 2      #Number of classes to predict (including background)
    args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)       # saving model file and log record during the process of training
    args.restore_from = './co_attention_davis_43.pth' #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #
    args.snapshot_dir = './snapshots/davis_iteration/'          #Where to save snapshots of the model
    args.save_segimage = True
    args.seg_save_dir = "./result/test/davis_iteration_conf_try"
    args.vis_save_dir = "./result/test/davis_vis"
    args.corp_size =(473, 473)

def convert_state_dict(state_dict):
    """Converts a state dict saved from a dataParallel module to normal 
       module state_dict inplace
       :param state_dict is the loaded DataParallel model_state
       You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it 
       without DataParallel. You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can 
       load the weights file, create a new ordered dict without the module prefix, and load it back 
    """
    state_dict_new = OrderedDict()
    #print(type(state_dict))
    for k, v in state_dict.items():
        #print(k)
        name = k[7:]  # strip the 'module.' prefix added by nn.DataParallel
        state_dict_new[name] = v
        if name == 'linear_e.weight':
            np.save('weight_matrix.npy',v.cpu().numpy())
    return state_dict_new

def sigmoid(inX): 
    return 1.0/(1+np.exp(-inX))  # sigmoid: 1/(1+e^-x)

def main():
    args = get_arguments()
    print("=====> Configure dataset and model")
    configure_dataset_model(args)
    print(args)

    print("=====> Set GPU for training")
    if args.cuda:
        print("====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
    model = CoattentionNet(num_classes=args.num_classes, nframes = args.sample_range)
    for param in model.parameters():
        param.requires_grad = False

    saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage)
    #print(saved_state_dict.keys())
    #model.load_state_dict({k.replace('pspmodule.',''):v for k,v in torch.load(args.restore_from)['state_dict'].items()})
    model.load_state_dict( convert_state_dict(saved_state_dict["model"]) ) #convert_state_dict(saved_state_dict["model"])

    model.eval()
    model.cuda()

    db_test = db.PairwiseImg(train=False, inputRes=(473,473), db_root_dir=args.data_dir,  transform=None, seq_name = None, sample_range = args.sample_range) #db_root_dir() --> '/path/to/DAVIS-2016' train path
    testloader = data.DataLoader(db_test, batch_size= 1, shuffle=False, num_workers=0)
    voc_colorize = VOCColorize()

    data_list = []

    if args.save_segimage:
        if not os.path.exists(args.seg_save_dir) and not os.path.exists(args.vis_save_dir):
            os.makedirs(args.seg_save_dir)
            #os.makedirs(args.vis_save_dir)
    print("======> test set size:", len(testloader))
    my_index = 0
    old_temp=''
    for index, batch in enumerate(testloader):
        print('%d processed' % (index))
        target = batch['target']
        np.save('target.npy', target.float().data)
        #search = batch['search']
        temp = batch['seq_name']
        args.seq_name=temp[0]
        print(args.seq_name)
        if old_temp==args.seq_name:
            my_index = my_index+1
        else:
            my_index = 0
        output_sum = 0   
        for i in range(0,1):
            search = batch['search'+'_'+str(i)]
            search_im = search
            print('input size:', search_im.size(),len(search.size()))
            if len(search.size()) <5:
                search_im = search_im.unsqueeze(0)
            output = model(Variable(target, volatile=True).cuda(),Variable(search_im, volatile=True).cuda())
            #print(output[0])  # output has two elements
            output_sum = output_sum + output[0].data[0,0].cpu().numpy()  # result of the segmentation branch

            #np.save('infer'+str(i)+'.npy',output1)
            #output2 = output[1].data[0, 0].cpu().numpy() #interp'
        
        output1 = output_sum#/args.sample_range
        #target_mask = output[3].data[0,0].cpu().numpy()
        #print('output size:', output1.shape, type(output1))
        first_image = np.array(Image.open(args.data_dir+'/JPEGImages/480p/blackswan/00000.jpg'))
        original_shape = first_image.shape  # all DAVIS-2016 480p sequences share this resolution
        output1 = cv2.resize(output1, (original_shape[1],original_shape[0]))
        #output2 = cv2.resize(target_mask, (original_shape[1], original_shape[0]))
        mask = (output1*255).astype(np.uint8)
        #target_mask = (output2*255).astype(np.uint8)
        mask = Image.fromarray(mask)
        #target_mask = Image.fromarray(target_mask)

        save_dir_res = os.path.join(args.seg_save_dir, 'Results', args.seq_name)
        old_temp=args.seq_name
        if not os.path.exists(save_dir_res):
            os.makedirs(save_dir_res)
        if args.save_segimage:
            my_index1 = str(my_index).zfill(5)
            seg_filename = os.path.join(save_dir_res, '{}.png'.format(my_index1))
            gate_filename = os.path.join(save_dir_res, '{}_gate.png'.format(my_index1))
            mask.save(seg_filename)
            #target_mask.save(gate_filename)

if __name__ == '__main__':
    main()


================================================
FILE: train_iteration_conf.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 15 10:52:26 2018

@author: carri
"""
# differs from deeplab_co_attention_concat in that it trains with a new model (siamese_model_concat_new)

import argparse
import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import pickle
import cv2
from torch.autograd import Variable
import torch.optim as optim
import scipy.misc
import torch.backends.cudnn as cudnn
import sys
import os
#from utils.balanced_BCE import class_balanced_cross_entropy_loss
import os.path as osp
#from psp.model import PSPNet
#from dataloaders import davis_2016 as db
from dataloaders import PairwiseImg_video as db # follows the VOC dataset data-layout conventions
import matplotlib.pyplot as plt
import random
import timeit
#from psp.model1 import CoattentionNet  # co-attention model built on PSPNet
from deeplab.siamese_model_conf import CoattentionNet # this siamese model directly outputs the attended features
#from deeplab.utils import get_1x_lr_params, get_10x_lr_params#, adjust_learning_rate #, loss_calc
start = timeit.default_timer()

def get_arguments():
    """Parse all the arguments provided from the CLI.
    
    Returns:
      The parsed argument namespace.
    """
    parser = argparse.ArgumentParser(description="PSPnet Network")

    # optimization configuration
    parser.add_argument("--is-training", action="store_true", 
                        help="Whether to update the running means and variances during training.")
    parser.add_argument("--learning-rate", type=float, default= 0.00025, 
                        help="Base learning rate for training with polynomial decay.") #0.001
    parser.add_argument("--weight-decay", type=float, default= 0.0005, 
                        help="Regularization parameter for L2-loss.")  # 0.0005
    parser.add_argument("--momentum", type=float, default= 0.9, 
                        help="Momentum component of the optimiser.")
    parser.add_argument("--power", type=float, default= 0.9, 
                        help="Decay parameter to compute the learning rate.")
    # dataset information
    parser.add_argument("--dataset", type=str, default='cityscapes',
                        help="voc12, cityscapes, or pascal-context.")
    parser.add_argument("--random-mirror", action="store_true",
                        help="Whether to randomly mirror the inputs during the training.")
    parser.add_argument("--random-scale", action="store_true",
                        help="Whether to randomly scale the inputs during the training.")

    parser.add_argument("--not-restore-last", action="store_true",
                        help="Whether to not restore last (FC) layers.")
    parser.add_argument("--random-seed", type=int, default= 1234,
                        help="Random seed to have reproducible results.")
    parser.add_argument('--logFile', default='log.txt', 
                        help='File that stores the training and validation logs')
    # GPU configuration
    parser.add_argument("--cuda", default=True, help="Run on CPU or GPU")
    parser.add_argument("--gpus", type=str, default="3", help="choose gpu device.") # defaults to GPU 3


    return parser.parse_args()

args = get_arguments()


def configure_dataset_init_model(args):
    if args.dataset == 'voc12':

        args.batch_size = 10 # 1 card: 5, 2 cards: 10; number of images sent to the network in one step (16 in the paper)
        args.maxEpoches = 15 # 1 card: 15, 2 cards: 15 epochs (~30k iterations); max iterations = maxEpoches*len(train_aug)/batch_size_per_gpu
        args.data_dir = '/home/wty/AllDataSet/VOC2012'   # Path to the directory containing the PASCAL VOC dataset
        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # Path to the file listing the images in the dataset
        args.ignore_label = 255     #The index of the label to ignore during the training
        args.input_size = '473,473' #Comma-separated string with height and width of images
        args.num_classes = 21      #Number of classes to predict (including background)

        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)
        # saving model file and log record during the process of training

        #Where restore model pretrained on other dataset, such as COCO.")
        args.restore_from = './pretrained/MS_DeepLab_resnet_pretrained_COCO_init.pth'
        args.snapshot_dir = './snapshots/voc12/'          #Where to save snapshots of the model
        args.resume = './snapshots/voc12/psp_voc12_3.pth' #checkpoint log file, helping recovering training
        
    elif args.dataset == 'davis': 
        args.batch_size = 16 # number of samples sent to the network in one step (16 in the paper)
        args.maxEpoches = 60 # max iterations = maxEpoches*len(train)/batch_size
        args.data_dir = '/home/ubuntu/xiankai/dataset/DAVIS-2016'   # 37572 image pairs
        args.img_dir = '/home/ubuntu/xiankai/dataset/images'
        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # Path to the file listing the images in the dataset
        args.ignore_label = 255     #The index of the label to ignore during the training
        args.input_size = '473,473' #Comma-separated string with height and width of images
        args.num_classes = 2      #Number of classes to predict (including background)
        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)       # saving model file and log record during the process of training
        #Where restore model pretrained on other dataset, such as COCO.")
        args.restore_from = './pretrained/deep_labv3/deeplab_davis_12_0.pth' #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #
        args.snapshot_dir = './snapshots/davis_iteration_conf/'          #Where to save snapshots of the model
        args.resume = './snapshots/davis/co_attention_davis_124.pth' #checkpoint log file, helping recovering training
		
    elif args.dataset == 'cityscapes':
        args.batch_size = 8   #Number of images sent to the network in one step, batch_size/num_GPU=2
        args.maxEpoches = 60 # epoch count; 60 epochs ~= 90k iterations, max iterations = maxEpoches*len(train)/batch_size
        # 60x2975/2=89250 ~= 90k, single_GPU_batch_size=2
        args.data_dir = '/home/wty/AllDataSet/CityScapes'   # Path to the directory containing the CityScapes dataset
        args.data_list = './dataset/list/Cityscapes/cityscapes_train_list.txt'  # Path to the file listing the images in the dataset
        args.ignore_label = 255     #The index of the label to ignore during the training
        args.input_size = '720,720' #Comma-separated string with height and width of images
        args.num_classes = 19      #Number of classes to predict (including background)

        args.img_mean = np.array((73.15835921, 82.90891754, 72.39239876), dtype=np.float32)
        # saving model file and log record during the process of training

        #Where restore model pretrained on other dataset, such as coarse cityscapes
        args.restore_from = './pretrained/resnet101_pretrained_for_cityscapes.pth'
        args.snapshot_dir = './snapshots/cityscapes/'          #Where to save snapshots of the model
        args.resume = './snapshots/cityscapes/psp_cityscapes_12_3.pth' #checkpoint log file, helping recovering training
       
    else:
        print("dataset error")

def adjust_learning_rate(optimizer, i_iter, epoch, max_iter):
    """Polynomially decays the learning rate; every third iteration only the encoder group trains at the full rate, otherwise the co-attention group trains at 10x and the encoder at 0.01x."""
    
    lr = lr_poly(args.learning_rate, i_iter, max_iter, args.power, epoch)
    optimizer.param_groups[0]['lr'] = lr
    if i_iter%3 ==0:
        optimizer.param_groups[0]['lr'] = lr
        optimizer.param_groups[1]['lr'] = 0
    else:
        optimizer.param_groups[0]['lr'] = 0.01*lr
        optimizer.param_groups[1]['lr'] = lr * 10
        
    return lr
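
# Illustrative sketch of the alternation above (hypothetical helper, not
# called by the training loop): every third iteration only the first
# parameter group (the encoder) trains at the full poly-decayed rate;
# otherwise the co-attention group trains at 10x while the encoder is
# slowed to 1/100.
def _demo_group_lrs(i_iter, lr):
    """Return (encoder_lr, coattention_lr) for iteration i_iter."""
    if i_iter % 3 == 0:
        return lr, 0.0
    return 0.01 * lr, 10 * lr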

def loss_calc1(pred, label):
    """
    This function returns a class-balanced binary cross-entropy loss for segmentation, weighting each pixel by the reciprocal of the foreground-pixel ratio
    """
    labels = torch.ge(label, 0.5).float()
#    
    batch_size = label.size()
    #print(batch_size)
    num_labels_pos = torch.sum(labels) 
#    
    batch_1 =  batch_size[0]* batch_size[2]
    batch_1 = batch_1* batch_size[3]
    weight_1 = torch.div(num_labels_pos, batch_1) # pos ratio
    weight_1 = torch.reciprocal(weight_1)
    #print(num_labels_pos, batch_1)
    weight_2 = torch.div(batch_1-num_labels_pos, batch_1)
    #print('postive ratio', weight_2, weight_1)
    weight_22 = torch.mul(weight_1,  torch.ones(batch_size[0], batch_size[1], batch_size[2], batch_size[3]).cuda())
    #weight_11 = torch.mul(weight_1,  torch.ones(batch_size[0], batch_size[1], batch_size[2]).cuda())
    criterion = torch.nn.BCELoss(weight = weight_22)#weight = torch.Tensor([0,1]) .cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    #loss = class_balanced_cross_entropy_loss(pred, label).cuda()
        
    return criterion(pred, label)
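
# Minimal sketch of the class-balancing weight above (toy numbers;
# _demo_foreground_weight is a hypothetical helper): the BCE weight is the
# reciprocal of the foreground-pixel ratio over the batch, so masks with
# sparse foreground are up-weighted. The channel dimension is excluded,
# matching batch_1 = N*H*W above for single-channel masks.
def _demo_foreground_weight(num_pos, n, h, w):
    """Mirrors weight_1 above: (n*h*w) / num_pos."""
    return (n * h * w) / float(num_pos)
# e.g. a 4x1x8x8 batch with 16 positive pixels -> each pixel's BCE term
# is scaled by (4*8*8)/16 = 16.0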

def loss_calc2(pred, label):
    """
    This function returns an L1 loss between the prediction and the label
    """
    # out shape batch_size x channels x h x w -> batch_size x channels x h x w
    # label shape h x w x 1 x batch_size  -> batch_size x 1 x h x w
    # Variable(label.long()).cuda()
    criterion = torch.nn.L1Loss()#.cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    
    return criterion(pred, label)



def get_1x_lr_params(model):
    """
    This generator returns all the parameters of the net except for 
    the last classification layer. Note that for each batchnorm layer, 
    requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 
    any batchnorm parameter
    """
    b = []
    if torch.cuda.device_count() == 1:
        #b.append(model.encoder.conv1)
        #b.append(model.encoder.bn1)
        #b.append(model.encoder.layer1)
        #b.append(model.encoder.layer2)
        #b.append(model.encoder.layer3)
        #b.append(model.encoder.layer4)
        b.append(model.encoder.layer5)
    else:
        b.append(model.module.encoder.conv1)
        b.append(model.module.encoder.bn1)
        b.append(model.module.encoder.layer1)
        b.append(model.module.encoder.layer2)
        b.append(model.module.encoder.layer3)
        b.append(model.module.encoder.layer4)
        b.append(model.module.encoder.layer5)
        b.append(model.module.encoder.main_classifier)
    for i in range(len(b)):
        for j in b[i].modules():
            jj = 0
            for k in j.parameters():
                jj+=1
                if k.requires_grad:
                    yield k


def get_10x_lr_params(model):
    """
    This generator returns all the parameters for the last layer of the net,
    which does the classification of pixel into classes
    """
    b = []
    if torch.cuda.device_count() == 1:
        b.append(model.linear_e.parameters())
        b.append(model.main_classifier.parameters())
    else:
        #b.append(model.module.encoder.layer5.parameters())
        b.append(model.module.linear_e.parameters())
        b.append(model.module.conv1.parameters())
        b.append(model.module.conv2.parameters())
        b.append(model.module.gate.parameters())
        b.append(model.module.bn1.parameters())
        b.append(model.module.bn2.parameters())   
        b.append(model.module.main_classifier1.parameters())
        b.append(model.module.main_classifier2.parameters())
        
    for j in range(len(b)):
        for i in b[j]:
            yield i
            
def lr_poly(base_lr, iter, max_iter, power, epoch):
    if epoch<=2:
        factor = 1
    elif epoch>2 and epoch< 6:
        factor = 1
    else:
        factor = 0.5
    return base_lr*factor*((1-float(iter)/max_iter)**(power))
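
# Worked example of the schedule above (illustrative values; hypothetical
# helper): the rate starts at base_lr and decays as
# (1 - iter/max_iter)**power; from epoch 6 onward lr_poly additionally
# halves the whole curve via factor.
def _demo_poly_curve(base_lr=0.00025, max_iter=1000, power=0.9):
    """Sample the epoch<=2 branch of lr_poly (factor = 1) at three iterations."""
    return [base_lr * ((1 - it / float(max_iter)) ** power) for it in (0, 500, 900)]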


def netParams(model):
    '''
    Compute the total number of network parameters.
    Args:
       model: the model
    Returns: total parameter count
    '''
    total_paramters = 0
    for parameter in model.parameters():
        i = len(parameter.size())
        #print(parameter.size())
        p = 1
        for j in range(i):
            p *= parameter.size(j)
        total_paramters += p

    return total_paramters
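
# Stdlib sketch of the count above (the toy shapes are hypothetical):
# netParams multiplies out each parameter tensor's shape and sums the
# products.
def _demo_count_params(shapes):
    total = 0
    for shape in shapes:
        p = 1
        for dim in shape:
            p *= dim
        total += p
    return total
# e.g. a 3x3 convolution with 64 input and 64 output channels plus its bias:
# _demo_count_params([(64, 64, 3, 3), (64,)]) -> 36928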

def main():
    
    
    print("=====> Configure dataset and pretrained model")
    configure_dataset_init_model(args)
    print(args)

    print("    current dataset:  ", args.dataset)
    print("    init model: ", args.restore_from)
    print("=====> Set GPU for training")
    if args.cuda:
        print("====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
    # Select which GPU, -1 if CPU
    #gpu_id = args.gpus
    #device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
    print("=====> Random Seed: ", args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.random_seed) 

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    print("=====> Building network")
    saved_state_dict = torch.load(args.restore_from)
    model = CoattentionNet(num_classes=args.num_classes)
    #print(model)
    new_params = model.state_dict().copy()
    for i in saved_state_dict["model"]:
        #Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.') # handle key names saved from a multi-GPU model
        #i_parts.pop(1)
        #print('i_parts:  ', '.'.join(i_parts[1:-1]))
        #if  not i_parts[1]=='main_classifier': #and not '.'.join(i_parts[1:-1]) == 'layer5.bottleneck' and not '.'.join(i_parts[1:-1]) == 'layer5.bn':  #init model pretrained on COCO, class name=21, layer5 is ASPP
        new_params['encoder'+'.'+'.'.join(i_parts[1:])] = saved_state_dict["model"][i]
            #print('copy {}'.format('.'.join(i_parts[1:])))
    
   
    print("=====> Loading init weights,  pretrained COCO for VOC2012, and pretrained Coarse cityscapes for cityscapes")
 
            
    model.load_state_dict(new_params) # only the pretrained encoder parameters (through ResNet layer5) are reused
    #print(model.keys())
    if args.cuda:
        #model.to(device)
        if torch.cuda.device_count()>1:
            print("torch.cuda.device_count()=",torch.cuda.device_count())
            model = torch.nn.DataParallel(model).cuda()  #multi-card data parallel
        else:
            print("single GPU for training")
            model = model.cuda()  #1-card data parallel
    start_epoch=0
    
    print("=====> Whether resuming from a checkpoint, for continuing training")
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint["epoch"] 
            model.load_state_dict(checkpoint["model"])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))


    model.train()
    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    
    print('=====> Computing network parameters')
    total_paramters = netParams(model)
    print('Total network parameters: ' + str(total_paramters))
 
    print("=====> Preparing training data")
    if args.dataset == 'voc12':
        trainloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, 
                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), 
                                      batch_size= args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    elif args.dataset == 'cityscapes':
        trainloader = data.DataLoader(CityscapesDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, 
                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), 
                                      batch_size = args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    elif args.dataset == 'davis':  #for davis 2016
        db_train = db.PairwiseImg(train=True, inputRes=input_size, db_root_dir=args.data_dir, img_root_dir=args.img_dir,  transform=None) #db_root_dir() --> '/path/to/DAVIS-2016' train path
        trainloader = data.DataLoader(db_train, batch_size= args.batch_size, shuffle=True, num_workers=0)
    else:
        print("dataset error")

    optimizer = optim.SGD([{'params': get_1x_lr_params(model), 'lr': 1*args.learning_rate },  # layer-specific learning rates; some layers are frozen
                {'params': get_10x_lr_params(model), 'lr': 10*args.learning_rate}], 
                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()


    
    logFileLoc = args.snapshot_dir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t\t%s" % ('iter', 'Loss(train)\n'))
    logger.flush()

    print("=====> Begin to train")
    train_len=len(trainloader)
    print("  iterations per epoch: ", train_len)
    print("  epoch num: ", args.maxEpoches)
    print("  max iteration: ", args.maxEpoches*train_len)
    
    for epoch in range(start_epoch, int(args.maxEpoches)):
        
        np.random.seed(args.random_seed + epoch)
        for i_iter, batch in enumerate(trainloader,0): #i_iter from 0 to len-1
            #print("i_iter=", i_iter, "epoch=", epoch)
            target, target_gt, search, search_gt = batch['target'], batch['target_gt'], batch['search'], batch['search_gt']
            images, labels = batch['img'], batch['img_gt']
            #print(labels.size())
            images.requires_grad_()
            images = Variable(images).cuda()
            labels = Variable(labels.float().unsqueeze(1)).cuda()
            
            target.requires_grad_()
            target = Variable(target).cuda()
            target_gt = Variable(target_gt.float().unsqueeze(1)).cuda()
            
            search.requires_grad_()
            search = Variable(search).cuda()
            search_gt = Variable(search_gt.float().unsqueeze(1)).cuda()
            
            optimizer.zero_grad()
            
            lr = adjust_learning_rate(optimizer, i_iter+epoch*train_len, epoch,
                    max_iter = args.maxEpoches * train_len)
            #print(images.size())
            if i_iter%3 ==0: # training step on static images
                
                pred1, pred2, pred3 = model(images, images)
                loss = 0.1*(loss_calc1(pred3, labels) + 0.8* loss_calc2(pred3, labels) )
                loss.backward()
                
            else:
                    
                pred1, pred2, pred3 = model(target, search)
                loss = loss_calc1(pred1, target_gt) + 0.8* loss_calc2(pred1, target_gt) + loss_calc1(pred2, search_gt) + 0.8* loss_calc2(pred2, search_gt)#class_balanced_cross_entropy_loss(pred, labels, size_average=False)
                loss.backward()
            
            optimizer.step()
                
            print("===> Epoch[{}]({}/{}): Loss: {:.10f}  lr: {:.5f}".format(epoch, i_iter, train_len, loss.data, lr))
            logger.write("Epoch[{}]({}/{}):     Loss: {:.10f}      lr: {:.5f}\n".format(epoch, i_iter, train_len, loss.data, lr))
            logger.flush()
                
        print("=====> saving model")
        state={"epoch": epoch+1, "model": model.state_dict()}
        torch.save(state, osp.join(args.snapshot_dir, 'co_attention_'+str(args.dataset)+"_"+str(epoch)+'.pth'))


    end = timeit.default_timer()
    print( float(end-start)/3600, 'h')
    logger.write("total training time: {:.2f} h\n".format(float(end-start)/3600))
    logger.close()


if __name__ == '__main__':
    main()


================================================
FILE: train_iteration_conf_group.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 15 10:52:26 2018

@author: carri
"""
# Differs from deeplab_co_attention_concat in that it trains with the new model (siamese_model_concat_new)

import argparse
import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import pickle
import cv2
from torch.autograd import Variable
import torch.optim as optim
import scipy.misc
import torch.backends.cudnn as cudnn
import sys
import os
from utils.balanced_BCE import class_balanced_cross_entropy_loss  # note: the utils package is not included in this repository snapshot
import os.path as osp
#from psp.model import PSPNet
#from dataloaders import davis_2016 as db
from dataloaders import PairwiseImg_video_try as db # follows the VOC dataset data-layout conventions
import matplotlib.pyplot as plt
import random
import timeit
#from psp.model1 import CoattentionNet  # co-attention model built on PSPNet
from deeplab.siamese_model_conf_try import CoattentionNet # this siamese model directly outputs the attended features
#from deeplab.utils import get_1x_lr_params, get_10x_lr_params#, adjust_learning_rate #, loss_calc
start = timeit.default_timer()

def get_arguments():
    """Parse all the arguments provided from the CLI.
    
    Returns:
      The parsed argument namespace.
    """
    parser = argparse.ArgumentParser(description="PSPnet Network")

    # optimization configuration
    parser.add_argument("--is-training", action="store_true", 
                        help="Whether to update the running means and variances during training.")
    parser.add_argument("--learning-rate", type=float, default= 0.00025, 
                        help="Base learning rate for training with polynomial decay.") #0.001
    parser.add_argument("--weight-decay", type=float, default= 0.0005, 
                        help="Regularization parameter for L2-loss.")  # 0.0005
    parser.add_argument("--momentum", type=float, default= 0.9, 
                        help="Momentum component of the optimiser.")
    parser.add_argument("--power", type=float, default= 0.9, 
                        help="Decay parameter to compute the learning rate.")
    # dataset information
    parser.add_argument("--dataset", type=str, default='cityscapes',
                        help="voc12, cityscapes, or pascal-context.")
    parser.add_argument("--random-mirror", action="store_true",
                        help="Whether to randomly mirror the inputs during the training.")
    parser.add_argument("--random-scale", action="store_true",
                        help="Whether to randomly scale the inputs during the training.")

    parser.add_argument("--not-restore-last", action="store_true",
                        help="Whether to not restore last (FC) layers.")
    parser.add_argument("--random-seed", type=int, default= 1234,
                        help="Random seed to have reproducible results.")
    parser.add_argument('--logFile', default='log.txt', 
                        help='File that stores the training and validation logs')
    # GPU configuration
    parser.add_argument("--cuda", default=True, help="Run on CPU or GPU")
    parser.add_argument("--gpus", type=str, default="3", help="choose gpu device.") # defaults to GPU 3


    return parser.parse_args()

args = get_arguments()


def configure_dataset_init_model(args):
    if args.dataset == 'voc12':

        args.batch_size = 10 # 1 card: 5, 2 cards: 10; number of images sent to the network in one step (16 in the paper)
        args.maxEpoches = 15 # 1 card: 15, 2 cards: 15 epochs (~30k iterations); max iterations = maxEpoches*len(train_aug)/batch_size_per_gpu
        args.data_dir = '/home/wty/AllDataSet/VOC2012'   # Path to the directory containing the PASCAL VOC dataset
        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # Path to the file listing the images in the dataset
        args.ignore_label = 255     #The index of the label to ignore during the training
        args.input_size = '473,473' #Comma-separated string with height and width of images
        args.num_classes = 21      #Number of classes to predict (including background)

        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)
        # saving model file and log record during the process of training

        #Where restore model pretrained on other dataset, such as COCO.")
        args.restore_from = './pretrained/MS_DeepLab_resnet_pretrained_COCO_init.pth'
        args.snapshot_dir = './snapshots/voc12/'          #Where to save snapshots of the model
        args.resume = './snapshots/voc12/psp_voc12_3.pth' #checkpoint log file, helping recovering training
        
    elif args.dataset == 'davis': 
        args.batch_size = 16 # number of samples sent to the network in one step (16 in the paper)
        args.maxEpoches = 60 # max iterations = maxEpoches*len(train)/batch_size
        args.data_dir = '/home/ubuntu/xiankai/dataset/DAVIS-2016'   # 37572 image pairs
        args.img_dir = '/home/ubuntu/xiankai/dataset/images'
        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # Path to the file listing the images in the dataset
        args.ignore_label = 255     #The index of the label to ignore during the training
        args.input_size = '378,378' #Comma-separated string with height and width of images
        args.num_classes = 2      #Number of classes to predict (including background)
        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)       # saving model file and log record during the process of training
        #Where restore model pretrained on other dataset, such as COCO.")
        args.restore_from = './pretrained/deep_labv3/deeplab_davis_12_0.pth' #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #
        args.snapshot_dir = './snapshots/davis_iteration_conf_try/'          #Where to save snapshots of the model
        args.resume = './snapshots/davis/co_attention_davis_124.pth' #checkpoint log file, helping recovering training
		
    elif args.dataset == 'cityscapes':
        args.batch_size = 8   #Number of images sent to the network in one step, batch_size/num_GPU=2
        args.maxEpoches = 60 # epoch count; 60 epochs ~= 90k iterations, max iterations = maxEpoches*len(train)/batch_size
        # 60x2975/2=89250 ~= 90k, single_GPU_batch_size=2
        args.data_dir = '/home/wty/AllDataSet/CityScapes'   # Path to the directory containing the CityScapes dataset
        args.data_list = './dataset/list/Cityscapes/cityscapes_train_list.txt'  # Path to the file listing the images in the dataset
        args.ignore_label = 255     #The index of the label to ignore during the training
        args.input_size = '720,720' #Comma-separated string with height and width of images
        args.num_classes = 19      #Number of classes to predict (including background)

        args.img_mean = np.array((73.15835921, 82.90891754, 72.39239876), dtype=np.float32)
        # saving model file and log record during the process of training

        #Where restore model pretrained on other dataset, such as coarse cityscapes
        args.restore_from = './pretrained/resnet101_pretrained_for_cityscapes.pth'
        args.snapshot_dir = './snapshots/cityscapes/'          #Where to save snapshots of the model
        args.resume = './snapshots/cityscapes/psp_cityscapes_12_3.pth' #checkpoint log file, helping recovering training
       
    else:
        print("dataset error")

def adjust_learning_rate(optimizer, i_iter, epoch, max_iter):
    """Polynomially decays the learning rate; every third iteration only the encoder group trains at the full rate, otherwise the co-attention group trains at 10x and the encoder at 0.01x."""
    
    lr = lr_poly(args.learning_rate, i_iter, max_iter, args.power, epoch)
    optimizer.param_groups[0]['lr'] = lr
    if i_iter%3 ==0:
        optimizer.param_groups[0]['lr'] = lr
        optimizer.param_groups[1]['lr'] = 0
    else:
        optimizer.param_groups[0]['lr'] = 0.01*lr
        optimizer.param_groups[1]['lr'] = lr * 10
        
    return lr

def loss_calc1(pred, label):
    """
    This function returns a class-balanced binary cross-entropy loss for segmentation, weighting each pixel by the reciprocal of the foreground-pixel ratio
    """
    labels = torch.ge(label, 0.5).float()
#    
    batch_size = label.size()
    #print(batch_size)
    num_labels_pos = torch.sum(labels) 
#    
    batch_1 =  batch_size[0]* batch_size[2]
    batch_1 = batch_1* batch_size[3]
    weight_1 = torch.div(num_labels_pos, batch_1) # pos ratio
    weight_1 = torch.reciprocal(weight_1)
    #print(num_labels_pos, batch_1)
    weight_2 = torch.div(batch_1-num_labels_pos, batch_1)
    #print('postive ratio', weight_2, weight_1)
    weight_22 = torch.mul(weight_1,  torch.ones(batch_size[0], batch_size[1], batch_size[2], batch_size[3]).cuda())
    #weight_11 = torch.mul(weight_1,  torch.ones(batch_size[0], batch_size[1], batch_size[2]).cuda())
    criterion = torch.nn.BCELoss(weight = weight_22)#weight = torch.Tensor([0,1]) .cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    #loss = class_balanced_cross_entropy_loss(pred, label).cuda()
        
    return criterion(pred, label)

def loss_calc2(pred, label):
    """
    This function returns an L1 loss between the prediction and the label
    """
    # out shape batch_size x channels x h x w -> batch_size x channels x h x w
    # label shape h x w x 1 x batch_size  -> batch_size x 1 x h x w
    # Variable(label.long()).cuda()
    criterion = torch.nn.L1Loss()#.cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    
    return criterion(pred, label)



def get_1x_lr_params(model):
    """
    This generator returns all the parameters of the net except for 
    the last classification layer. Note that for each batchnorm layer, 
    requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 
    any batchnorm parameter
    """
    b = []
    if torch.cuda.device_count() == 1:
        #b.append(model.encoder.conv1)
        #b.append(model.encoder.bn1)
        #b.append(model.encoder.layer1)
        #b.append(model.encoder.layer2)
        #b.append(model.encoder.layer3)
        #b.append(model.encoder.layer4)
        b.append(model.encoder.layer5)
    else:
        b.append(model.module.encoder.conv1)
        b.append(model.module.encoder.bn1)
        b.append(model.module.encoder.layer1)
        b.append(model.module.encoder.layer2)
        b.append(model.module.encoder.layer3)
        b.append(model.module.encoder.layer4)
        b.append(model.module.encoder.layer5)
        b.append(model.module.encoder.main_classifier)
    for i in range(len(b)):
        for j in b[i].modules():
            jj = 0
            for k in j.parameters():
                jj+=1
                if k.requires_grad:
                    yield k


def get_10x_lr_params(model):
    """
    This generator returns all the parameters for the last layer of the net,
    which does the classification of pixel into classes
    """
    b = []
    if torch.cuda.device_count() == 1:
        b.append(model.linear_e.parameters())
        b.append(model.main_classifier.parameters())
    else:
        #b.append(model.module.encoder.layer5.parameters())
        b.append(model.module.linear_e.parameters())
        b.append(model.module.conv1.parameters())
        #b.append(model.module.conv2.parameters())
        b.append(model.module.gate.parameters())
        b.append(model.module.bn1.parameters())
        #b.append(model.module.bn2.parameters())
        b.append(model.module.main_classifier1.parameters())
        #b.append(model.module.main_classifier2.parameters())
        
    for j in range(len(b)):
        for i in b[j]:
            yield i
            
def lr_poly(base_lr, iter, max_iter, power, epoch):
    if epoch<=2:
        factor = 1
    elif epoch>2 and epoch< 6:
        factor = 1
    else:
        factor = 0.5
    return base_lr*factor*((1-float(iter)/max_iter)**(power))


def netParams(model):
    '''
    Compute the total number of network parameters.
    Args:
       model: the model
    Returns: total parameter count
    '''
    total_paramters = 0
    for parameter in model.parameters():
        i = len(parameter.size())
        #print(parameter.size())
        p = 1
        for j in range(i):
            p *= parameter.size(j)
        total_paramters += p

    return total_paramters

def main():
    
    
    print("=====> Configure dataset and pretrained model")
    configure_dataset_init_model(args)
    print(args)

    print("    current dataset:  ", args.dataset)
    print("    init model: ", args.restore_from)
    print("=====> Set GPU for training")
    if args.cuda:
        print("====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
    # Select which GPU, -1 if CPU
    #gpu_id = args.gpus
    #device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
    print("=====> Random Seed: ", args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.random_seed) 

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    print("=====> Building network")
    saved_state_dict = torch.load(args.restore_from)
    model = CoattentionNet(num_classes=args.num_classes)
    #print(model)
    new_params = model.state_dict().copy()
    for i in saved_state_dict["model"]:
        #Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.') # handle the multi-GPU case: DataParallel saves keys with a 'module.' prefix
        #i_parts.pop(1)
        #print('i_parts:  ', '.'.join(i_parts[1:-1]))
        #if  not i_parts[1]=='main_classifier': #and not '.'.join(i_parts[1:-1]) == 'layer5.bottleneck' and not '.'.join(i_parts[1:-1]) == 'layer5.bn':  #init model pretrained on COCO, class name=21, layer5 is ASPP
        new_params['encoder'+'.'+'.'.join(i_parts[1:])] = saved_state_dict["model"][i]
            #print('copy {}'.format('.'.join(i_parts[1:])))
    
   
    print("=====> Loading init weights,  pretrained COCO for VOC2012, and pretrained Coarse cityscapes for cityscapes")
 
            
    model.load_state_dict(new_params) # only the parameters up to ResNet's conv5 block are reused
    #print(model.keys())
    if args.cuda:
        #model.to(device)
        if torch.cuda.device_count()>1:
            print("torch.cuda.device_count()=",torch.cuda.device_count())
            model = torch.nn.DataParallel(model).cuda()  #multi-card data parallel
        else:
            print("single GPU for training")
            model = model.cuda()  #1-card data parallel
    start_epoch=0
    
    print("=====> Whether resuming from a checkpoint, for continuing training")
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint["epoch"] 
            model.load_state_dict(checkpoint["model"])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))


    model.train()
    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    
    print('=====> Computing network parameters')
    total_parameters = netParams(model)
    print('Total network parameters: ' + str(total_parameters))
 
    print("=====> Preparing training data")
    if args.dataset == 'voc12':
        trainloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, 
                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), 
                                      batch_size= args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    elif args.dataset == 'cityscapes':
        trainloader = data.DataLoader(CityscapesDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, 
                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), 
                                      batch_size = args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    elif args.dataset == 'davis':  #for davis 2016
        db_train = db.PairwiseImg(train=True, inputRes=input_size, db_root_dir=args.data_dir, img_root_dir=args.img_dir,  transform=None) #db_root_dir() --> '/path/to/DAVIS-2016' train path
        trainloader = data.DataLoader(db_train, batch_size= args.batch_size, shuffle=True, num_workers=0)
    else:
        print("dataset error")

    optimizer = optim.SGD([{'params': get_1x_lr_params(model), 'lr': 1*args.learning_rate },  # per-group learning rates; some layers are not trained
                {'params': get_10x_lr_params(model), 'lr': 10*args.learning_rate}], 
                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()


    
    logFileLoc = args.snapshot_dir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t\t%s" % ('iter', 'Loss(train)\n'))
    logger.flush()

    print("=====> Begin to train")
    train_len=len(trainloader)
    print("  iterations per epoch: ", train_len)
    print("  epoch num: ", args.maxEpoches)
    print("  max iteration: ", args.maxEpoches*train_len)
    
    for epoch in range(start_epoch, int(args.maxEpoches)):
        
        np.random.seed(args.random_seed + epoch)
        for i_iter, batch in enumerate(trainloader,0): #i_iter from 0 to len-1
            #print("i_iter=", i_iter, "epoch=", epoch)
            target, target_gt, search, search_gt = batch['target'], batch['target_grt'], batch['search'], batch['search_grt']
            images, labels = batch['img'], batch['img_grt']
            #print('input size:', len(target), target.size(),labels.size())
            #8,2,3,473,473
            images.requires_grad_()
            images = Variable(images).cuda()
            labels = Variable(labels.float().unsqueeze(1)).cuda()
            
            target.requires_grad_()
            target = Variable(target).cuda()
            target_gt = Variable(target_gt.float().unsqueeze(1)).cuda()
            
            search.requires_grad_()
            search = Variable(search).cuda()
            search_gt = Variable(search_gt.float().unsqueeze(1)).cuda()
            
            optimizer.zero_grad()
            
            lr = adjust_learning_rate(optimizer, i_iter+epoch*train_len, epoch,
                    max_iter = args.maxEpoches * train_len)
            #print(images.size())
            if i_iter % 3 == 0: # training on static images
                
                pred1, pred2  = model(images, images)
                loss = 0.1*(loss_calc1(pred2, labels) + 0.8* loss_calc2(pred2, labels))
                loss.backward()
                
            else:
                    
                pred1, pred2 = model(target, search)
                #print('video prediction size:', pred2.size(),target_gt.size())
                loss = loss_calc1(pred1, target_gt) + 0.8* loss_calc2(pred1, target_gt)
                loss.backward()
            
            optimizer.step()
                
            print("===> Epoch[{}]({}/{}): Loss: {:.10f}  lr: {:.5f}".format(epoch, i_iter, train_len, loss.data, lr))
            logger.write("Epoch[{}]({}/{}):     Loss: {:.10f}      lr: {:.5f}\n".format(epoch, i_iter, train_len, loss.data, lr))
            logger.flush()
                
        print("=====> saving model")
        state={"epoch": epoch+1, "model": model.state_dict()}
        torch.save(state, osp.join(args.snapshot_dir, 'co_attention_'+str(args.dataset)+"_"+str(epoch)+'.pth'))


    end = timeit.default_timer()
    print( float(end-start)/3600, 'h')
    logger.write("total training time: {:.2f} h\n".format(float(end-start)/3600))
    logger.close()


if __name__ == '__main__':
    main()
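One detail worth pulling out of `main()`: the `i_iter % 3` test interleaves one static-image (saliency) iteration with two video-pair iterations, so roughly a third of the updates come from image data. A sketch of which branch fires over the first nine iterations:

```python
# reproduce the branch selection from the training loop: iteration i trains
# on static images when i % 3 == 0, otherwise on a video frame pair
static_iters = [i for i in range(9) if i % 3 == 0]
video_iters = [i for i in range(9) if i % 3 != 0]
print(static_iters)  # [0, 3, 6]
print(video_iters)   # [1, 2, 4, 5, 7, 8]
```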
SYMBOL INDEX (119 symbols across 12 files)

FILE: dataloaders/PairwiseImg_test.py
  function flip (line 20) | def flip(I,flip_p):
  function scale_im (line 26) | def scale_im(img_temp,scale):
  function scale_gt (line 30) | def scale_gt(img_temp,scale):
  function my_crop (line 34) | def my_crop(img,gt):
  class PairwiseImg (line 46) | class PairwiseImg(Dataset):
    method __init__ (line 49) | def __init__(self, train=True,
    method __len__ (line 106) | def __len__(self):
    method __getitem__ (line 109) | def __getitem__(self, idx):
    method make_img_gt_pair (line 141) | def make_img_gt_pair(self, idx): #this function exists to support __getitem__
    method get_img_size (line 186) | def get_img_size(self):

FILE: dataloaders/PairwiseImg_video.py
  function flip (line 20) | def flip(I,flip_p):
  function scale_im (line 26) | def scale_im(img_temp,scale):
  function scale_gt (line 30) | def scale_gt(img_temp,scale):
  function my_crop (line 34) | def my_crop(img,gt):
  class PairwiseImg (line 46) | class PairwiseImg(Dataset):
    method __init__ (line 49) | def __init__(self, train=True,
    method __len__ (line 125) | def __len__(self):
    method __getitem__ (line 129) | def __getitem__(self, idx):
    method make_video_gt_pair (line 163) | def make_video_gt_pair(self, idx): #this function exists to support __getitem__
    method get_img_size (line 207) | def get_img_size(self):
    method make_img_gt_pair (line 212) | def make_img_gt_pair(self, idx): #this function exists to support __getitem__

FILE: dataloaders/PairwiseImg_video_test_try.py
  function flip (line 20) | def flip(I,flip_p):
  function scale_im (line 26) | def scale_im(img_temp,scale):
  function scale_gt (line 30) | def scale_gt(img_temp,scale):
  function my_crop (line 34) | def my_crop(img,gt):
  class PairwiseImg (line 46) | class PairwiseImg(Dataset):
    method __init__ (line 49) | def __init__(self, train=True,
    method __len__ (line 106) | def __len__(self):
    method __getitem__ (line 109) | def __getitem__(self, idx):
    method make_img_gt_pair (line 147) | def make_img_gt_pair(self, idx): #this function exists to support __getitem__
    method get_img_size (line 192) | def get_img_size(self):

FILE: dataloaders/PairwiseImg_video_try.py
  function flip (line 20) | def flip(I,flip_p):
  function scale_im (line 26) | def scale_im(img_temp,scale):
  function scale_gt (line 30) | def scale_gt(img_temp,scale):
  function my_crop (line 34) | def my_crop(img,gt):
  class PairwiseImg (line 46) | class PairwiseImg(Dataset):
    method __init__ (line 49) | def __init__(self, train=True,
    method __len__ (line 125) | def __len__(self):
    method __getitem__ (line 129) | def __getitem__(self, idx):
    method make_video_gt_pair (line 174) | def make_video_gt_pair(self, idx): #this function exists to support __getitem__
    method get_img_size (line 218) | def get_img_size(self):
    method make_img_gt_pair (line 223) | def make_img_gt_pair(self, idx): #this function exists to support __getitem__

FILE: deeplab/siamese_model_conf.py
  function conv3x3 (line 15) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 21) | class BasicBlock(nn.Module):
    method __init__ (line 24) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 34) | def forward(self, x):
  class Bottleneck (line 53) | class Bottleneck(nn.Module):
    method __init__ (line 56) | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=...
    method forward (line 70) | def forward(self, x):
  class ASPP (line 93) | class ASPP(nn.Module):
    method __init__ (line 94) | def __init__(self, dilation_series, padding_series, depth):
    method _make_stage_ (line 121) | def _make_stage_(self, dilation1, padding1):
    method forward (line 128) | def forward(self, x):
  class ResNet (line 160) | class ResNet(nn.Module):
    method __init__ (line 161) | def __init__(self, block, layers, num_classes):
    method _make_layer (line 184) | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
    method _make_pred_layer (line 201) | def _make_pred_layer(self, block, dilation_series, padding_series, num...
    method forward (line 204) | def forward(self, x):
  class CoattentionModel (line 222) | class CoattentionModel(nn.Module):
    method __init__ (line 223) | def  __init__(self, block, layers, num_classes, all_channel=256, all_d...
    method forward (line 252) | def forward(self, input1, input2): #note: input2 can be multiple frames
  function Res_Deeplab (line 299) | def Res_Deeplab(num_classes=2):
  function CoattentionNet (line 303) | def CoattentionNet(num_classes=2):

FILE: deeplab/siamese_model_conf_try_single.py
  function conv3x3 (line 16) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 22) | class BasicBlock(nn.Module):
    method __init__ (line 25) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 35) | def forward(self, x):
  class Bottleneck (line 54) | class Bottleneck(nn.Module):
    method __init__ (line 57) | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=...
    method forward (line 71) | def forward(self, x):
  class ASPP (line 94) | class ASPP(nn.Module):
    method __init__ (line 95) | def __init__(self, dilation_series, padding_series, depth):
    method _make_stage_ (line 122) | def _make_stage_(self, dilation1, padding1):
    method forward (line 129) | def forward(self, x):
  class ResNet (line 161) | class ResNet(nn.Module):
    method __init__ (line 162) | def __init__(self, block, layers, num_classes):
    method _make_layer (line 185) | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
    method _make_pred_layer (line 202) | def _make_pred_layer(self, block, dilation_series, padding_series, num...
    method forward (line 205) | def forward(self, x):
  class CoattentionModel (line 223) | class CoattentionModel(nn.Module):
    method __init__ (line 224) | def  __init__(self, block, layers, num_classes,  all_channel=256, all_...
    method forward (line 254) | def forward(self, input1, input2): #note: input2 can be multiple frames
  function CoattentionNet (line 309) | def CoattentionNet(num_classes=2,nframes=2):

FILE: deeplab/utils.py
  function loss_calc (line 6) | def loss_calc(pred, label, ignore_label):
  function lr_poly (line 18) | def lr_poly(base_lr, iter, max_iter, power):
  function get_1x_lr_params (line 22) | def get_1x_lr_params(model):
  function get_10x_lr_params (line 40) | def get_10x_lr_params(model):
  function adjust_learning_rate (line 52) | def adjust_learning_rate(optimizer, i_iter, learning_rate, num_steps, po...

FILE: densecrf_apply_cvpr2019.py
  function worker (line 22) | def worker(scale, g_dim, g_factor,s_dim,C_dim,c_factor):

FILE: test_coattention_conf.py
  function get_arguments (line 37) | def get_arguments():
  function configure_dataset_model (line 57) | def configure_dataset_model(args):
  function convert_state_dict (line 89) | def convert_state_dict(state_dict):
  function sigmoid (line 108) | def sigmoid(inX):
  function main (line 111) | def main():

FILE: test_iteration_conf_group.py
  function get_arguments (line 37) | def get_arguments():
  function configure_dataset_model (line 57) | def configure_dataset_model(args):
  function convert_state_dict (line 74) | def convert_state_dict(state_dict):
  function sigmoid (line 93) | def sigmoid(inX):
  function main (line 96) | def main():

FILE: train_iteration_conf.py
  function get_arguments (line 35) | def get_arguments():
  function configure_dataset_init_model (line 78) | def configure_dataset_init_model(args):
  function adjust_learning_rate (line 133) | def adjust_learning_rate(optimizer, i_iter, epoch, max_iter):
  function loss_calc1 (line 147) | def loss_calc1(pred, label):
  function loss_calc2 (line 171) | def loss_calc2(pred, label):
  function get_1x_lr_params (line 184) | def get_1x_lr_params(model):
  function get_10x_lr_params (line 218) | def get_10x_lr_params(model):
  function lr_poly (line 242) | def lr_poly(base_lr, iter, max_iter, power, epoch):
  function netParams (line 252) | def netParams(model):
  function main (line 270) | def main():

FILE: train_iteration_conf_group.py
  function get_arguments (line 35) | def get_arguments():
  function configure_dataset_init_model (line 78) | def configure_dataset_init_model(args):
  function adjust_learning_rate (line 133) | def adjust_learning_rate(optimizer, i_iter, epoch, max_iter):
  function loss_calc1 (line 147) | def loss_calc1(pred, label):
  function loss_calc2 (line 171) | def loss_calc2(pred, label):
  function get_1x_lr_params (line 184) | def get_1x_lr_params(model):
  function get_10x_lr_params (line 218) | def get_10x_lr_params(model):
  function lr_poly (line 242) | def lr_poly(base_lr, iter, max_iter, power, epoch):
  function netParams (line 252) | def netParams(model):
  function main (line 270) | def main():
