[
  {
    "path": "README.md",
    "content": "# COSNet\nCode for CVPR 2019 paper: \n\n[See More, Know More: Unsupervised Video Object Segmentation with\nCo-Attention Siamese Networks](http://openaccess.thecvf.com/content_CVPR_2019/papers/Lu_See_More_Know_More_Unsupervised_Video_Object_Segmentation_With_Co-Attention_CVPR_2019_paper.pdf)\n\n[Xiankai Lu](https://sites.google.com/site/xiankailu111/), [Wenguan Wang](https://sites.google.com/view/wenguanwang), Chao Ma, Jianbing Shen, Ling Shao, Fatih Porikli\n\n##\n\n![](../master/framework.png)\n\n- - -\n:new:\n\nOur group co-attention achieves a further performance gain (81.1 mean J on DAVIS-16 dataset), related codes have also been released.\n\nThe pre-trained model, testing and training code:\n\n### Quick Start\n\n#### Testing\n\n1. Install pytorch (version:1.0.1).\n\n2. Download the pretrained model. Run 'test_coattention_conf.py' and change the davis dataset path, pretrainde model path and result path.\n\n3. Run command: python test_coattention_conf.py --dataset davis --gpus 0\n\n4. Post CRF processing code comes from: https://github.com/lucasb-eyer/pydensecrf. \n\nThe pretrained weight can be download from [GoogleDrive](https://drive.google.com/open?id=14ya3ZkneeHsegCgDrvkuFtGoAfVRgErz) or [BaiduPan](https://pan.baidu.com/s/16oFzRmn4Meuq83fCYr4boQ), pass code: xwup.\n\nThe segmentation results on DAVIS, FBMS and Youtube-objects can be download from DAVIS_benchmark(https://davischallenge.org/davis2016/soa_compare.html) or\n[GoogleDrive](https://drive.google.com/open?id=1JRPc2kZmzx0b7WLjxTPD-kdgFdXh5gBq) or [BaiduPan](https://pan.baidu.com/s/11n7zAt3Lo2P3-42M2lsw6Q), pass code: q37f.\n\nThe youtube-objects dataset can be downloaded from [here](http://calvin-vision.net/datasets/youtube-objects-dataset/) and annotation can be found [here](http://vision.cs.utexas.edu/projects/videoseg/data_download_register.html).\n\nThe FBMS dataset can be downloaded from [here](https://lmb.informatik.uni-freiburg.de/resources/datasets/moseg.en.html).\n#### Training\n\n1. Download all the training datasets, including MARA10K and DUT saliency datasets. Create a folder called images and put these two datasets into the folder. \n\n2. Download the deeplabv3 model from [GoogleDrive](https://drive.google.com/open?id=1hy0-BAEestT9H4a3Sv78xrHrzmZga9mj). Put it into the folder pretrained/deep_labv3.\n\n3. Change the video path, image path and deeplabv3 path in train_iteration_conf.py.  Create two txt files which store the saliency dataset name and DAVIS16 training sequences name. Change the txt path in PairwiseImg_video.py.\n\n4. 
#### Training\n\n1. Download all the training datasets, including the MSRA10K and DUT saliency datasets. Create a folder called images and put these two datasets into it.\n\n2. Download the deeplabv3 model from [GoogleDrive](https://drive.google.com/open?id=1hy0-BAEestT9H4a3Sv78xrHrzmZga9mj). Put it into the folder pretrained/deep_labv3.\n\n3. Change the video path, image path and deeplabv3 path in train_iteration_conf.py. Create two txt files that store the saliency dataset names and the DAVIS-16 training sequence names, and change the txt paths in PairwiseImg_video.py.\n\n4. Run command: python train_iteration_conf.py --dataset davis --gpus 0,1\n\n### Citation\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@InProceedings{Lu_2019_CVPR,  \nauthor = {Lu, Xiankai and Wang, Wenguan and Ma, Chao and Shen, Jianbing and Shao, Ling and Porikli, Fatih},  \ntitle = {See More, Know More: Unsupervised Video Object Segmentation With Co-Attention Siamese Networks},  \nbooktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},  \nyear = {2019}  \n}\n@article{lu2020_pami,\n  title={Zero-Shot Video Object Segmentation with Co-Attention Siamese Networks},\n  author={Lu, Xiankai and Wang, Wenguan and Shen, Jianbing and Crandall, David and Luo, Jiebo},\n  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},\n  year={2020},\n  publisher={IEEE}\n}\n```\n### Other related projects/papers:\n[Saliency-Aware Geodesic Video Object Segmentation (CVPR15)](https://github.com/wenguanwang/saliencysegment)\n\n[Learning Unsupervised Video Primary Object Segmentation through Visual Attention (CVPR19)](https://github.com/wenguanwang/AGS)\n\nAny comments, please email: carrierlxk@gmail.com\n
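\n#### Dataloader sketch\n\nFor reference, this is roughly how the training dataloader is constructed; the directory paths are placeholders, and note that the saliency txt path is currently hard-coded inside PairwiseImg_video.py (step 3 above), so change it before running.\n\n```python\n# Illustrative only: replace the placeholder paths with your local ones.\nfrom torch.utils.data import DataLoader\nfrom dataloaders.PairwiseImg_video import PairwiseImg\n\ndataset = PairwiseImg(train=True,\n                      inputRes=(473, 473),\n                      db_root_dir='/path/to/DAVIS-2016',  # video frames + annotations\n                      img_root_dir='/path/to/images',     # MSRA10K + DUT saliency data\n                      sample_range=10)\nloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)\nbatch = next(iter(loader))  # keys: target, target_gt, search, search_gt, img, img_gt\n```\n"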
  },
  {
    "path": "dataloaders/PairwiseImg_test.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Wed Sep 12 11:39:54 2018\r\n\r\n@author: carri\r\n\"\"\"\r\n# for testing case\r\nfrom __future__ import division\r\n\r\nimport os\r\nimport numpy as np\r\nimport cv2\r\nfrom scipy.misc import imresize\r\nimport scipy.misc \r\nimport random\r\n\r\n#from dataloaders.helpers import *\r\nfrom torch.utils.data import Dataset\r\n\r\ndef flip(I,flip_p):\r\n    if flip_p>0.5:\r\n        return np.fliplr(I)\r\n    else:\r\n        return I\r\n\r\ndef scale_im(img_temp,scale):\r\n    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )\r\n    return cv2.resize(img_temp,new_dims).astype(float)\r\n\r\ndef scale_gt(img_temp,scale):\r\n    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )\r\n    return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float)\r\n\r\ndef my_crop(img,gt):\r\n    H = int(0.9 * img.shape[0])\r\n    W = int(0.9 * img.shape[1])\r\n    H_offset = random.choice(range(img.shape[0] - H))\r\n    W_offset = random.choice(range(img.shape[1] - W))\r\n    H_slice = slice(H_offset, H_offset + H)\r\n    W_slice = slice(W_offset, W_offset + W)\r\n    img = img[H_slice, W_slice, :]\r\n    gt = gt[H_slice, W_slice]\r\n    \r\n    return img, gt\r\n\r\nclass PairwiseImg(Dataset):\r\n    \"\"\"DAVIS 2016 dataset constructed using the PyTorch built-in functionalities\"\"\"\r\n\r\n    def __init__(self, train=True,\r\n                 inputRes=None,\r\n                 db_root_dir='/DAVIS-2016',\r\n                 transform=None,\r\n                 meanval=(104.00699, 116.66877, 122.67892),\r\n                 seq_name=None, sample_range=10):\r\n        \"\"\"Loads image to label pairs for tool pose estimation\r\n        db_root_dir: dataset directory with subfolders \"JPEGImages\" and \"Annotations\"\r\n        \"\"\"\r\n        self.train = train\r\n        self.range = sample_range\r\n        self.inputRes = inputRes\r\n        self.db_root_dir = db_root_dir\r\n        self.transform = transform\r\n        self.meanval = meanval\r\n        self.seq_name = seq_name\r\n\r\n        if self.train:\r\n            fname = 'train_seqs'\r\n        else:\r\n            fname = 'val_seqs'\r\n\r\n        if self.seq_name is None: #所有的数据集都参与训练\r\n            with open(os.path.join(db_root_dir, fname + '.txt')) as f:\r\n                seqs = f.readlines()\r\n                img_list = []\r\n                labels = []\r\n                Index = {}\r\n                for seq in seqs:                    \r\n                    images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\\n'))))\r\n                    images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images))\r\n                    start_num = len(img_list)\r\n                    img_list.extend(images_path)\r\n                    end_num = len(img_list)\r\n                    Index[seq.strip('\\n')]= np.array([start_num, end_num])\r\n                    lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\\n'))))\r\n                    lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab))\r\n                    labels.extend(lab_path)\r\n        else: #针对所有的训练样本， img_list存放的是图片的路径\r\n\r\n            # Initialize the per sequence images for online training\r\n            names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name))))\r\n            img_list = 
list(map(lambda x: os.path.join(( str(seq_name)), x), names_img))\r\n            #name_label = np.sort(os.listdir(os.path.join(db_root_dir,  str(seq_name))))\r\n            labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])]\r\n            labels.extend([None]*(len(names_img)-1)) #在labels这个列表后面添加元素None\r\n            if self.train:\r\n                img_list = [img_list[0]]\r\n                labels = [labels[0]]\r\n\r\n        assert (len(labels) == len(img_list))\r\n\r\n        self.img_list = img_list\r\n        self.labels = labels\r\n        self.Index = Index\r\n        #img_files = open('all_im.txt','w+')\r\n\r\n    def __len__(self):\r\n        return len(self.img_list)\r\n\r\n    def __getitem__(self, idx):\r\n        target, target_gt,sequence_name = self.make_img_gt_pair(idx) #测试时候要分割的帧\r\n        target_id = idx\r\n        seq_name1 = self.img_list[target_id].split('/')[-2] #获取视频名称\r\n        sample = {'target': target, 'target_gt': target_gt, 'seq_name': sequence_name, 'search_0': None}\r\n        if self.range>=1:\r\n            my_index = self.Index[seq_name1]\r\n            search_num = list(range(my_index[0], my_index[1]))  \r\n            search_ids = random.sample(search_num, self.range)#min(len(self.img_list)-1, target_id+np.random.randint(1,self.range+1))\r\n        \r\n            for i in range(0,self.range):\r\n                search_id = search_ids[i]\r\n                search, search_gt,sequence_name = self.make_img_gt_pair(search_id)\r\n                if sample['search_0'] is None:\r\n                    sample['search_0'] = search\r\n                else:\r\n                    sample['search'+'_'+str(i)] = search\r\n            #np.save('search1.npy',search)\r\n            #np.save('search_gt.npy',search_gt)\r\n            if self.seq_name is not None:\r\n                fname = os.path.join(self.seq_name, \"%05d\" % idx)\r\n                sample['fname'] = fname\r\n       \r\n        else:\r\n            img, gt = self.make_img_gt_pair(idx)\r\n            sample = {'image': img, 'gt': gt}\r\n            if self.seq_name is not None:\r\n                fname = os.path.join(self.seq_name, \"%05d\" % idx)\r\n                sample['fname'] = fname\r\n\r\n        return sample  #这个类最后的输出\r\n\r\n    def make_img_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的\r\n        \"\"\"\r\n        Make the image-ground-truth pair\r\n        \"\"\"\r\n        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx]), cv2.IMREAD_COLOR)\r\n        if self.labels[idx] is not None and self.train:\r\n            label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE)\r\n            #print(os.path.join(self.db_root_dir, self.labels[idx]))\r\n        else:\r\n            gt = np.zeros(img.shape[:-1], dtype=np.uint8)\r\n            \r\n         ## 已经读取了image以及对应的ground truth可以进行data augmentation了\r\n        if self.train:  #scaling, cropping and flipping\r\n             img, label = my_crop(img,label)\r\n             scale = random.uniform(0.7, 1.3)\r\n             flip_p = random.uniform(0, 1)\r\n             img_temp = scale_im(img,scale)\r\n             img_temp = flip(img_temp,flip_p)\r\n             gt_temp = scale_gt(label,scale)\r\n             gt_temp = flip(gt_temp,flip_p)\r\n             \r\n             img = img_temp\r\n             label = gt_temp\r\n             \r\n        if self.inputRes is not None:\r\n            img = imresize(img, self.inputRes)\r\n            #print('ok1')\r\n            
#scipy.misc.imsave('label.png',label)\r\n            #scipy.misc.imsave('img.png',img)\r\n            if self.labels[idx] is not None and self.train:\r\n                label = imresize(label, self.inputRes, interp='nearest')\r\n\r\n        img = np.array(img, dtype=np.float32)\r\n        #img = img[:, :, ::-1]\r\n        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        \r\n        img = img.transpose((2, 0, 1))  # NHWC -> NCHW\r\n        \r\n        if self.labels[idx] is not None and self.train:\r\n                gt = np.array(label, dtype=np.int32)\r\n                gt[gt!=0]=1\r\n                #gt = gt/np.max([gt.max(), 1e-8])\r\n        #np.save('gt.npy')\r\n        sequence_name = self.img_list[idx].split('/')[2]\r\n        return img, gt, sequence_name \r\n\r\n    def get_img_size(self):\r\n        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))\r\n        \r\n        return list(img.shape[:2])\r\n\r\n\r\nif __name__ == '__main__':\r\n    import custom_transforms as tr\r\n    import torch\r\n    from torchvision import transforms\r\n    from matplotlib import pyplot as plt\r\n\r\n    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])\r\n\r\n    #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',\r\n                       # train=True, transform=transforms)\r\n    #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)\r\n#\r\n#    for i, data in enumerate(dataloader):\r\n#        plt.figure()\r\n#        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))\r\n#        if i == 10:\r\n#            break\r\n#\r\n#    plt.show(block=True)\r\n"
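\r\n    # A runnable alternative to the commented-out demo above (paths are placeholders):\r\n    # dataset = PairwiseImg(db_root_dir='/path/to/DAVIS-2016', train=False, sample_range=5)\r\n    # print(len(dataset), dataset.get_img_size())\r\n"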
  },
  {
    "path": "dataloaders/PairwiseImg_video.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Wed Sep 12 11:39:54 2018\n\n@author: carri\n\"\"\"\n\nfrom __future__ import division\n\nimport os\nimport numpy as np\nimport cv2\nfrom scipy.misc import imresize\nimport scipy.misc \nimport random\n\nfrom dataloaders.helpers import *\nfrom torch.utils.data import Dataset\n\ndef flip(I,flip_p):\n    if flip_p>0.5:\n        return np.fliplr(I)\n    else:\n        return I\n\ndef scale_im(img_temp,scale):\n    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )\n    return cv2.resize(img_temp,new_dims).astype(float)\n\ndef scale_gt(img_temp,scale):\n    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )\n    return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float)\n\ndef my_crop(img,gt):\n    H = int(0.9 * img.shape[0])\n    W = int(0.9 * img.shape[1])\n    H_offset = random.choice(range(img.shape[0] - H))\n    W_offset = random.choice(range(img.shape[1] - W))\n    H_slice = slice(H_offset, H_offset + H)\n    W_slice = slice(W_offset, W_offset + W)\n    img = img[H_slice, W_slice, :]\n    gt = gt[H_slice, W_slice]\n    \n    return img, gt\n\nclass PairwiseImg(Dataset):\n    \"\"\"DAVIS 2016 dataset constructed using the PyTorch built-in functionalities\"\"\"\n\n    def __init__(self, train=True,\n                 inputRes=None,\n                 db_root_dir='/DAVIS-2016',\n                 img_root_dir = None,\n                 transform=None,\n                 meanval=(104.00699, 116.66877, 122.67892),\n                 seq_name=None, sample_range=10):\n        \"\"\"Loads image to label pairs for tool pose estimation\n        db_root_dir: dataset directory with subfolders \"JPEGImages\" and \"Annotations\"\n        \"\"\"\n        self.train = train\n        self.range = sample_range\n        self.inputRes = inputRes\n        self.img_root_dir = img_root_dir\n        self.db_root_dir = db_root_dir\n        self.transform = transform\n        self.meanval = meanval\n        self.seq_name = seq_name\n\n        if self.train:\n            fname = 'train_seqs'\n        else:\n            fname = 'val_seqs'\n\n        if self.seq_name is None: #所有的数据集都参与训练\n            with open(os.path.join(db_root_dir, fname + '.txt')) as f:\n                seqs = f.readlines()\n                video_list = []\n                labels = []\n                Index = {}\n                image_list = []\n                im_label = []\n                for seq in seqs:                    \n                    images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\\n'))))\n                    images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images))\n                    start_num = len(video_list)\n                    video_list.extend(images_path)\n                    end_num = len(video_list)\n                    Index[seq.strip('\\n')]= np.array([start_num, end_num])\n                    lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\\n'))))\n                    lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab))\n                    labels.extend(lab_path)\n                    \n                with open('/home/ubuntu/xiankai/saliency_data.txt') as f:\n                    seqs = f.readlines()\n                #data_list = np.sort(os.listdir(db_root_dir))\n                    for seq in seqs: #所有数据集\n                        seq = 
seq.strip('\\n') \n                        images = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/images/'))#针对某个数据集，比如DUT\t\t\t\n            # Initialize the original DAVIS splits for training the parent network\n                        images_path = list(map(lambda x: os.path.join((seq +'/images'), x), images))         \n                        image_list.extend(images_path)\n                        lab = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/saliencymaps'))\n                        lab_path = list(map(lambda x: os.path.join((seq +'/saliencymaps'),x), lab))\n                        im_label.extend(lab_path)\n        else: #针对所有的训练样本， video_list存放的是图片的路径\n\n            # Initialize the per sequence images for online training\n            names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name))))\n            video_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img))\n            #name_label = np.sort(os.listdir(os.path.join(db_root_dir,  str(seq_name))))\n            labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])]\n            labels.extend([None]*(len(names_img)-1)) #在labels这个列表后面添加元素None\n            if self.train:\n                video_list = [video_list[0]]\n                labels = [labels[0]]\n\n        assert (len(labels) == len(video_list))\n\n        self.video_list = video_list\n        self.labels = labels\n        self.image_list = image_list\n        self.img_labels = im_label\n        self.Index = Index\n        #img_files = open('all_im.txt','w+')\n\n    def __len__(self):\n        print(len(self.video_list), len(self.image_list))\n        return len(self.video_list)\n    \n    def __getitem__(self, idx):\n        target, target_gt = self.make_video_gt_pair(idx)\n        target_id = idx\n        img_idx = np.random.randint(1,len(self.image_list)-1)\n        seq_name1 = self.video_list[idx].split('/')[-2] #获取视频名称\n        if self.train:\n            my_index = self.Index[seq_name1]\n            search_id = np.random.randint(my_index[0], my_index[1])#min(len(self.video_list)-1, target_id+np.random.randint(1,self.range+1))\n            if search_id == target_id:\n                search_id = np.random.randint(my_index[0], my_index[1])\n            search, search_gt = self.make_video_gt_pair(search_id)\n            img, img_gt = self.make_img_gt_pair(img_idx)\n            sample = {'target': target, 'target_gt': target_gt, 'search': search, 'search_gt': search_gt, \\\n                      'img': img, 'img_gt': img_gt}\n            #np.save('search1.npy',search)\n            #np.save('search_gt.npy',search_gt)\n            if self.seq_name is not None:\n                fname = os.path.join(self.seq_name, \"%05d\" % idx)\n                sample['fname'] = fname\n\n            if self.transform is not None:\n                sample = self.transform(sample)\n       \n        else:\n            img, gt = self.make_video_gt_pair(idx)\n            sample = {'image': img, 'gt': gt}\n            if self.seq_name is not None:\n                fname = os.path.join(self.seq_name, \"%05d\" % idx)\n                sample['fname'] = fname\n        \n        \n        \n        return sample  #这个类最后的输出\n\n    def make_video_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的\n        \"\"\"\n        Make the image-ground-truth pair\n        \"\"\"\n        img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[idx]), cv2.IMREAD_COLOR)\n        if self.labels[idx] is not None and self.train:\n      
      label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE)\n            #print(os.path.join(self.db_root_dir, self.labels[idx]))\n        else:\n            gt = np.zeros(img.shape[:-1], dtype=np.uint8)\n            \n         ## 已经读取了image以及对应的ground truth可以进行data augmentation了\n        if self.train:  #scaling, cropping and flipping\n             img, label = my_crop(img,label)\n             scale = random.uniform(0.7, 1.3)\n             flip_p = random.uniform(0, 1)\n             img_temp = scale_im(img,scale)\n             img_temp = flip(img_temp,flip_p)\n             gt_temp = scale_gt(label,scale)\n             gt_temp = flip(gt_temp,flip_p)\n             \n             img = img_temp\n             label = gt_temp\n             \n        if self.inputRes is not None:\n            img = imresize(img, self.inputRes)\n            #print('ok1')\n            #scipy.misc.imsave('label.png',label)\n            #scipy.misc.imsave('img.png',img)\n            if self.labels[idx] is not None and self.train:\n                label = imresize(label, self.inputRes, interp='nearest')\n\n        img = np.array(img, dtype=np.float32)\n        #img = img[:, :, ::-1]\n        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        \n        img = img.transpose((2, 0, 1))  # NHWC -> NCHW\n        \n        if self.labels[idx] is not None and self.train:\n                gt = np.array(label, dtype=np.int32)\n                gt[gt!=0]=1\n                #gt = gt/np.max([gt.max(), 1e-8])\n        #np.save('gt.npy')\n        return img, gt\n\n    def get_img_size(self):\n        img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[0]))\n        \n        return list(img.shape[:2])\n\n    def make_img_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的\n        \"\"\"\n        Make the image-ground-truth pair\n        \"\"\"\n        img = cv2.imread(os.path.join(self.img_root_dir, self.image_list[idx]),cv2.IMREAD_COLOR)\n        #print(os.path.join(self.db_root_dir, self.img_list[idx]))\n        if self.img_labels[idx] is not None and self.train:\n            label = cv2.imread(os.path.join(self.img_root_dir, self.img_labels[idx]),cv2.IMREAD_GRAYSCALE)\n            #print(os.path.join(self.db_root_dir, self.labels[idx]))\n        else:\n            gt = np.zeros(img.shape[:-1], dtype=np.uint8)\n            \n        if self.inputRes is not None:            \n            img = imresize(img, self.inputRes)\n            if self.img_labels[idx] is not None and self.train:\n                label = imresize(label, self.inputRes, interp='nearest')\n\n        img = np.array(img, dtype=np.float32)\n        #img = img[:, :, ::-1]\n        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        \n        img = img.transpose((2, 0, 1))  # NHWC -> NCHW\n        \n        if self.img_labels[idx] is not None and self.train:\n                gt = np.array(label, dtype=np.int32)\n                gt[gt!=0]=1\n                #gt = gt/np.max([gt.max(), 1e-8])\n        #np.save('gt.npy')\n        return img, gt\n    \nif __name__ == '__main__':\n    import custom_transforms as tr\n    import torch\n    from torchvision import transforms\n    from matplotlib import pyplot as plt\n\n    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])\n\n    #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',\n                       # train=True, transform=transforms)\n    
#dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)\n#\n#    for i, data in enumerate(dataloader):\n#        plt.figure()\n#        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))\n#        if i == 10:\n#            break\n#\n#    plt.show(block=True)"
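\n\n    # A runnable alternative to the commented-out demo above (paths are placeholders):\n    # dataset = PairwiseImg(db_root_dir='/path/to/DAVIS-2016', img_root_dir='/path/to/images', train=True)\n    # loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)\n"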
  },
  {
    "path": "dataloaders/PairwiseImg_video_test_try.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Wed Sep 12 11:39:54 2018\r\n\r\n@author: carri\r\n\"\"\"\r\n# for testing case\r\nfrom __future__ import division\r\n\r\nimport os\r\nimport numpy as np\r\nimport cv2\r\nfrom scipy.misc import imresize\r\nimport scipy.misc \r\nimport random\r\nimport torch\r\nfrom dataloaders.helpers import *\r\nfrom torch.utils.data import Dataset\r\n\r\ndef flip(I,flip_p):\r\n    if flip_p>0.5:\r\n        return np.fliplr(I)\r\n    else:\r\n        return I\r\n\r\ndef scale_im(img_temp,scale):\r\n    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )\r\n    return cv2.resize(img_temp,new_dims).astype(float)\r\n\r\ndef scale_gt(img_temp,scale):\r\n    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )\r\n    return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float)\r\n\r\ndef my_crop(img,gt):\r\n    H = int(0.9 * img.shape[0])\r\n    W = int(0.9 * img.shape[1])\r\n    H_offset = random.choice(range(img.shape[0] - H))\r\n    W_offset = random.choice(range(img.shape[1] - W))\r\n    H_slice = slice(H_offset, H_offset + H)\r\n    W_slice = slice(W_offset, W_offset + W)\r\n    img = img[H_slice, W_slice, :]\r\n    gt = gt[H_slice, W_slice]\r\n    \r\n    return img, gt\r\n\r\nclass PairwiseImg(Dataset):\r\n    \"\"\"DAVIS 2016 dataset constructed using the PyTorch built-in functionalities\"\"\"\r\n\r\n    def __init__(self, train=True,\r\n                 inputRes=None,\r\n                 db_root_dir='/DAVIS-2016',\r\n                 transform=None,\r\n                 meanval=(104.00699, 116.66877, 122.67892),\r\n                 seq_name=None, sample_range=10):\r\n        \"\"\"Loads image to label pairs for tool pose estimation\r\n        db_root_dir: dataset directory with subfolders \"JPEGImages\" and \"Annotations\"\r\n        \"\"\"\r\n        self.train = train\r\n        self.range = sample_range\r\n        self.inputRes = inputRes\r\n        self.db_root_dir = db_root_dir\r\n        self.transform = transform\r\n        self.meanval = meanval\r\n        self.seq_name = seq_name\r\n\r\n        if self.train:\r\n            fname = 'train_seqs'\r\n        else:\r\n            fname = 'val_seqs'\r\n\r\n        if self.seq_name is None: #所有的数据集都参与训练\r\n            with open(os.path.join(db_root_dir, fname + '.txt')) as f:\r\n                seqs = f.readlines()\r\n                img_list = []\r\n                labels = []\r\n                Index = {}\r\n                for seq in seqs:                    \r\n                    images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\\n'))))\r\n                    images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images))\r\n                    start_num = len(img_list)\r\n                    img_list.extend(images_path)\r\n                    end_num = len(img_list)\r\n                    Index[seq.strip('\\n')]= np.array([start_num, end_num])\r\n                    lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\\n'))))\r\n                    lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab))\r\n                    labels.extend(lab_path)\r\n        else: #针对所有的训练样本， img_list存放的是图片的路径\r\n\r\n            # Initialize the per sequence images for online training\r\n            names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name))))\r\n            img_list = 
list(map(lambda x: os.path.join(( str(seq_name)), x), names_img))\r\n            #name_label = np.sort(os.listdir(os.path.join(db_root_dir,  str(seq_name))))\r\n            labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])]\r\n            labels.extend([None]*(len(names_img)-1)) #在labels这个列表后面添加元素None\r\n            if self.train:\r\n                img_list = [img_list[0]]\r\n                labels = [labels[0]]\r\n\r\n        assert (len(labels) == len(img_list))\r\n\r\n        self.img_list = img_list\r\n        self.labels = labels\r\n        self.Index = Index\r\n        #img_files = open('all_im.txt','w+')\r\n\r\n    def __len__(self):\r\n        return len(self.img_list)\r\n\r\n    def __getitem__(self, idx):\r\n        target, target_grt,sequence_name = self.make_img_gt_pair(idx) #测试时候要分割的帧\r\n        target_id = idx\r\n        seq_name1 = self.img_list[target_id].split('/')[-2] #获取视频名称\r\n\r\n        #target_grts = torch.stack((torch.from_numpy(target_grt), torch.from_numpy(target_grt_1)))\r\n        #print('video name', seq_name1 )\r\n        sample = {'target': target, 'target_gt': target_grt, 'seq_name': sequence_name, 'search_0': None}\r\n        if self.range>=1:\r\n            my_index = self.Index[seq_name1]\r\n            search_num = list(range(my_index[0], my_index[1]))\r\n            search_ids = random.sample(search_num, self.range)#min(len(self.img_list)-1, target_id+np.random.randint(1,self.range+1))\r\n            searchs=[]\r\n            for i in range(0,self.range):\r\n\r\n                search_id = search_ids[i]\r\n                search, search_grt,sequence_name = self.make_img_gt_pair(search_id)\r\n                searchs.append(torch.from_numpy(search))\r\n                #search_grts = torch.stack((torch.from_numpy(search_grt), torch.from_numpy(search_grt_1)))\r\n            if sample['search_0'] is None:\r\n                sample['search_0'] = torch.stack(searchs,dim=0)\r\n            else:\r\n                sample['search'+'_'+str(i)] = torch.stack(searchs)\r\n            #np.save('search1.npy',search)\r\n            #np.save('search_gt.npy',search_gt)\r\n            if self.seq_name is not None:\r\n                fname = os.path.join(self.seq_name, \"%05d\" % idx)\r\n                sample['fname'] = fname\r\n       \r\n        else:\r\n            img, gt = self.make_img_gt_pair(idx)\r\n            sample = {'image': img, 'gt': gt}\r\n            if self.seq_name is not None:\r\n                fname = os.path.join(self.seq_name, \"%05d\" % idx)\r\n                sample['fname'] = fname\r\n\r\n        return sample  #这个类最后的输出\r\n\r\n    def make_img_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的\r\n        \"\"\"\r\n        Make the image-ground-truth pair\r\n        \"\"\"\r\n        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx]), cv2.IMREAD_COLOR)\r\n        if self.labels[idx] is not None and self.train:\r\n            label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE)\r\n            #print(os.path.join(self.db_root_dir, self.labels[idx]))\r\n        else:\r\n            gt = np.zeros(img.shape[:-1], dtype=np.uint8)\r\n            \r\n         ## 已经读取了image以及对应的ground truth可以进行data augmentation了\r\n        if self.train:  #scaling, cropping and flipping\r\n             img, label = my_crop(img,label)\r\n             scale = random.uniform(0.7, 1.3)\r\n             flip_p = random.uniform(0, 1)\r\n             img_temp = scale_im(img,scale)\r\n             img_temp = 
flip(img_temp,flip_p)\r\n             gt_temp = scale_gt(label,scale)\r\n             gt_temp = flip(gt_temp,flip_p)\r\n             \r\n             img = img_temp\r\n             label = gt_temp\r\n             \r\n        if self.inputRes is not None:\r\n            img = imresize(img, self.inputRes)\r\n            #print('ok1')\r\n            #scipy.misc.imsave('label.png',label)\r\n            #scipy.misc.imsave('img.png',img)\r\n            if self.labels[idx] is not None and self.train:\r\n                label = imresize(label, self.inputRes, interp='nearest')\r\n\r\n        img = np.array(img, dtype=np.float32)\r\n        #img = img[:, :, ::-1]\r\n        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        \r\n        img = img.transpose((2, 0, 1))  # NHWC -> NCHW\r\n        \r\n        if self.labels[idx] is not None and self.train:\r\n                gt = np.array(label, dtype=np.int32)\r\n                gt[gt!=0]=1\r\n                #gt = gt/np.max([gt.max(), 1e-8])\r\n        #np.save('gt.npy')\r\n        sequence_name = self.img_list[idx].split('/')[2]\r\n        return img, gt, sequence_name \r\n\r\n    def get_img_size(self):\r\n        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))\r\n        \r\n        return list(img.shape[:2])\r\n\r\n\r\nif __name__ == '__main__':\r\n    import custom_transforms as tr\r\n    import torch\r\n    from torchvision import transforms\r\n    from matplotlib import pyplot as plt\r\n\r\n    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])\r\n\r\n    #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',\r\n                       # train=True, transform=transforms)\r\n    #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)\r\n#\r\n#    for i, data in enumerate(dataloader):\r\n#        plt.figure()\r\n#        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))\r\n#        if i == 10:\r\n#            break\r\n#\r\n#    plt.show(block=True)\r\n"
  },
  {
    "path": "dataloaders/PairwiseImg_video_try.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Wed Sep 12 11:39:54 2018\n\n@author: carri\n\"\"\"\n\nfrom __future__ import division\n\nimport os\nimport numpy as np\nimport cv2\nfrom scipy.misc import imresize\nimport scipy.misc \nimport random\nimport torch\nfrom dataloaders.helpers import *\nfrom torch.utils.data import Dataset\n\ndef flip(I,flip_p):\n    if flip_p>0.5:\n        return np.fliplr(I)\n    else:\n        return I\n\ndef scale_im(img_temp,scale):\n    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )\n    return cv2.resize(img_temp,new_dims).astype(float)\n\ndef scale_gt(img_temp,scale):\n    new_dims = (  int(img_temp.shape[0]*scale),  int(img_temp.shape[1]*scale)   )\n    return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float)\n\ndef my_crop(img,gt):\n    H = int(0.9 * img.shape[0])\n    W = int(0.9 * img.shape[1])\n    H_offset = random.choice(range(img.shape[0] - H))\n    W_offset = random.choice(range(img.shape[1] - W))\n    H_slice = slice(H_offset, H_offset + H)\n    W_slice = slice(W_offset, W_offset + W)\n    img = img[H_slice, W_slice, :]\n    gt = gt[H_slice, W_slice]\n    \n    return img, gt\n\nclass PairwiseImg(Dataset):\n    \"\"\"DAVIS 2016 dataset constructed using the PyTorch built-in functionalities\"\"\"\n\n    def __init__(self, train=True,\n                 inputRes=None,\n                 db_root_dir='/DAVIS-2016',\n                 img_root_dir = None,\n                 transform=None,\n                 meanval=(104.00699, 116.66877, 122.67892),\n                 seq_name=None, sample_range=10):\n        \"\"\"Loads image to label pairs for tool pose estimation\n        db_root_dir: dataset directory with subfolders \"JPEGImages\" and \"Annotations\"\n        \"\"\"\n        self.train = train\n        self.range = sample_range\n        self.inputRes = inputRes\n        self.img_root_dir = img_root_dir\n        self.db_root_dir = db_root_dir\n        self.transform = transform\n        self.meanval = meanval\n        self.seq_name = seq_name\n\n        if self.train:\n            fname = 'train_seqs'\n        else:\n            fname = 'val_seqs'\n\n        if self.seq_name is None: #所有的数据集都参与训练\n            with open(os.path.join(db_root_dir, fname + '.txt')) as f:\n                seqs = f.readlines()\n                video_list = []\n                labels = []\n                Index = {}\n                image_list = []\n                im_label = []\n                for seq in seqs:                    \n                    images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\\n'))))\n                    images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images))\n                    start_num = len(video_list)\n                    video_list.extend(images_path)\n                    end_num = len(video_list)\n                    Index[seq.strip('\\n')]= np.array([start_num, end_num])\n                    lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\\n'))))\n                    lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab))\n                    labels.extend(lab_path)\n                    \n                with open('/home/ubuntu/xiankai/saliency_data.txt') as f:\n                    seqs = f.readlines()\n                #data_list = np.sort(os.listdir(db_root_dir))\n                    for seq in seqs: #所有数据集\n                        
seq = seq.strip('\\n') \n                        images = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/images/'))#针对某个数据集，比如DUT\t\t\t\n            # Initialize the original DAVIS splits for training the parent network\n                        images_path = list(map(lambda x: os.path.join((seq +'/images'), x), images))         \n                        image_list.extend(images_path)\n                        lab = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/saliencymaps'))\n                        lab_path = list(map(lambda x: os.path.join((seq +'/saliencymaps'),x), lab))\n                        im_label.extend(lab_path)\n        else: #针对所有的训练样本， video_list存放的是图片的路径\n\n            # Initialize the per sequence images for online training\n            names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name))))\n            video_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img))\n            #name_label = np.sort(os.listdir(os.path.join(db_root_dir,  str(seq_name))))\n            labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])]\n            labels.extend([None]*(len(names_img)-1)) #在labels这个列表后面添加元素None\n            if self.train:\n                video_list = [video_list[0]]\n                labels = [labels[0]]\n\n        assert (len(labels) == len(video_list))\n\n        self.video_list = video_list\n        self.labels = labels\n        self.image_list = image_list\n        self.img_labels = im_label\n        self.Index = Index\n        #img_files = open('all_im.txt','w+')\n\n    def __len__(self):\n        print(len(self.video_list), len(self.image_list))\n        return len(self.video_list)\n    \n    def __getitem__(self, idx):\n        target, target_grt = self.make_video_gt_pair(idx)\n        target_id = idx\n        img_idx = random.sample([my_i for my_i in range(0,len(self.image_list))],2)\n\n        seq_name1 = self.video_list[idx].split('/')[-2] #获取视频名称\n        my_index = self.Index[seq_name1]\n        video_idx = random.sample([my_i for my_i in range(my_index[0],my_index[1])],3)\n        target_1, target_grt_1 = self.make_video_gt_pair(video_idx[0])\n        #print('type:', type(target))\n\n        #targets = torch.stack((torch.from_numpy(target),torch.from_numpy(target_1)))\n        #target_grts = torch.stack((torch.from_numpy(target_grt),torch.from_numpy(target_grt_1)))\n        #print('size:', torch.from_numpy(target_grt).size(), torch.from_numpy(target_grt_1).size())\n        if self.train:\n            #my_index = self.Index[seq_name1]\n            search, search_grt = self.make_video_gt_pair(video_idx[1])\n            search_1, search_grt_1 = self.make_video_gt_pair(video_idx[2])\n            searchs = torch.stack((torch.from_numpy(search), torch.from_numpy(search_1)))\n            search_grts = torch.stack((torch.from_numpy(search_grt), torch.from_numpy(search_grt_1)))\n            img, img_grt = self.make_img_gt_pair(img_idx[0])\n            #img_1, img_grt_1 = self.make_img_gt_pair(img_idx[1])\n            #imgs = torch.stack((torch.from_numpy(img), torch.from_numpy(img_1)))\n            #img_grts = torch.stack((torch.torch.from_numpy(img_grt), torch.from_numpy(img_grt_1)))\n            sample = {'target': target, 'target_grt': target_grt, 'search': searchs, 'search_grt': search_grts, \\\n                      'img': img, 'img_grt': img_grt}\n            #np.save('search1.npy',search)\n            if self.seq_name is not None:\n                fname = os.path.join(self.seq_name, 
\"%05d\" % idx)\n                sample['fname'] = fname\n\n            if self.transform is not None:\n                sample = self.transform(sample)\n       \n        else:\n            img, gt = self.make_video_gt_pair(idx)\n            sample = {'image': img, 'gt': gt}\n            if self.seq_name is not None:\n                fname = os.path.join(self.seq_name, \"%05d\" % idx)\n                sample['fname'] = fname\n        \n        \n        \n        return sample  #这个类最后的输出\n\n    def make_video_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的\n        \"\"\"\n        Make the image-ground-truth pair\n        \"\"\"\n        img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[idx]), cv2.IMREAD_COLOR)\n        if self.labels[idx] is not None and self.train:\n            label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE)\n            #print(os.path.join(self.db_root_dir, self.labels[idx]))\n        else:\n            gt = np.zeros(img.shape[:-1], dtype=np.uint8)\n            \n         ## 已经读取了image以及对应的ground truth可以进行data augmentation了\n        if self.train:  #scaling, cropping and flipping\n             img, label = my_crop(img,label)\n             scale = random.uniform(0.7, 1.3)\n             flip_p = random.uniform(0, 1)\n             img_temp = scale_im(img,scale)\n             img_temp = flip(img_temp,flip_p)\n             gt_temp = scale_gt(label,scale)\n             gt_temp = flip(gt_temp,flip_p)\n             \n             img = img_temp\n             label = gt_temp\n             \n        if self.inputRes is not None:\n            img = imresize(img, self.inputRes)\n            #print('ok1')\n            #scipy.misc.imsave('label.png',label)\n            #scipy.misc.imsave('img.png',img)\n            if self.labels[idx] is not None and self.train:\n                label = imresize(label, self.inputRes, interp='nearest')\n\n        img = np.array(img, dtype=np.float32)\n        #img = img[:, :, ::-1]\n        img = np.subtract(img, np.array(self.meanval, dtype=np.float32))        \n        img = img.transpose((2, 0, 1))  # NHWC -> NCHW\n        \n        if self.labels[idx] is not None and self.train:\n                gt = np.array(label, dtype=np.int32)\n                gt[gt!=0]=1\n                #gt = gt/np.max([gt.max(), 1e-8])\n        #np.save('gt.npy')\n        return img, gt\n\n    def get_img_size(self):\n        img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[0]))\n        \n        return list(img.shape[:2])\n\n    def make_img_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的\n        \"\"\"\n        Make the image-ground-truth pair\n        \"\"\"\n        img = cv2.imread(os.path.join(self.img_root_dir, self.image_list[idx]),cv2.IMREAD_COLOR)\n        #print(os.path.join(self.db_root_dir, self.img_list[idx]))\n        if self.img_labels[idx] is not None and self.train:\n            label = cv2.imread(os.path.join(self.img_root_dir, self.img_labels[idx]),cv2.IMREAD_GRAYSCALE)\n            #print(os.path.join(self.db_root_dir, self.labels[idx]))\n        else:\n            gt = np.zeros(img.shape[:-1], dtype=np.uint8)\n            \n        if self.inputRes is not None:            \n            img = imresize(img, self.inputRes)\n            if self.img_labels[idx] is not None and self.train:\n                label = imresize(label, self.inputRes, interp='nearest')\n\n        img = np.array(img, dtype=np.float32)\n        #img = img[:, :, ::-1]\n        img = np.subtract(img, 
np.array(self.meanval, dtype=np.float32))        \n        img = img.transpose((2, 0, 1))  # NHWC -> NCHW\n        \n        if self.img_labels[idx] is not None and self.train:\n                gt = np.array(label, dtype=np.int32)\n                gt[gt!=0]=1\n                #gt = gt/np.max([gt.max(), 1e-8])\n        #np.save('gt.npy')\n        return img, gt\n    \nif __name__ == '__main__':\n    import custom_transforms as tr\n    import torch\n    from torchvision import transforms\n    from matplotlib import pyplot as plt\n\n    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])\n\n    #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',\n                       # train=True, transform=transforms)\n    #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)\n#\n#    for i, data in enumerate(dataloader):\n#        plt.figure()\n#        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))\n#        if i == 10:\n#            break\n#\n#    plt.show(block=True)"
  },
  {
    "path": "deeplab/__init__.py",
    "content": "\n"
  },
  {
    "path": "deeplab/siamese_model_conf.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Sun Sep 16 10:01:14 2018\r\n\r\n@author: carri\r\n\"\"\"\r\n\r\nimport torch.nn as nn\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom torch.nn import init\r\naffine_par = True\r\n#区别于siamese_model_concat的地方就是采用的最标准的deeplab_v3的基础网络，然后加上了非对称的分支\r\n\r\ndef conv3x3(in_planes, out_planes, stride=1):\r\n    \"\"\"3x3 convolution with padding\"\"\"\r\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\r\n                     padding=1, bias=False)\r\n\r\n\r\nclass BasicBlock(nn.Module):\r\n    expansion = 1\r\n\r\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\r\n        super(BasicBlock, self).__init__()\r\n        self.conv1 = conv3x3(inplanes, planes, stride)\r\n        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.conv2 = conv3x3(planes, planes)\r\n        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)\r\n        self.downsample = downsample\r\n        self.stride = stride\r\n\r\n    def forward(self, x):\r\n        residual = x\r\n\r\n        out = self.conv1(x)\r\n        out = self.bn1(out)\r\n        out = self.relu(out)\r\n\r\n        out = self.conv2(out)\r\n        out = self.bn2(out)\r\n\r\n        if self.downsample is not None:\r\n            residual = self.downsample(x)\r\n\r\n        out += residual\r\n        out = self.relu(out)\r\n\r\n        return out\r\n\r\n\r\nclass Bottleneck(nn.Module):\r\n    expansion = 4\r\n\r\n    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):\r\n        super(Bottleneck, self).__init__()\r\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)  # change\r\n        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)\r\n        padding = dilation\r\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,  # change\r\n                               padding=padding, bias=False, dilation=dilation)\r\n        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)\r\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\r\n        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.downsample = downsample\r\n        self.stride = stride\r\n\r\n    def forward(self, x):\r\n        residual = x\r\n\r\n        out = self.conv1(x)\r\n        out = self.bn1(out)\r\n        out = self.relu(out)\r\n\r\n        out = self.conv2(out)\r\n        out = self.bn2(out)\r\n        out = self.relu(out)\r\n\r\n        out = self.conv3(out)\r\n        out = self.bn3(out)\r\n\r\n        if self.downsample is not None:\r\n            residual = self.downsample(x)\r\n\r\n        out += residual\r\n        out = self.relu(out)\r\n\r\n        return out\r\n\r\n\r\nclass ASPP(nn.Module):\r\n    def __init__(self, dilation_series, padding_series, depth):\r\n        super(ASPP, self).__init__()\r\n        self.mean = nn.AdaptiveAvgPool2d((1,1))\r\n        self.conv= nn.Conv2d(2048, depth, 1,1)\r\n        self.bn_x = nn.BatchNorm2d(depth)\r\n        self.conv2d_0 = nn.Conv2d(2048, depth, kernel_size=1, stride=1)\r\n        self.bn_0 = nn.BatchNorm2d(depth)\r\n        self.conv2d_1 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[0], dilation=dilation_series[0])\r\n        self.bn_1 = nn.BatchNorm2d(depth)\r\n        self.conv2d_2 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, 
padding=padding_series[1], dilation=dilation_series[1])\r\n        self.bn_2 = nn.BatchNorm2d(depth)\r\n        self.conv2d_3 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[2], dilation=dilation_series[2])\r\n        self.bn_3 = nn.BatchNorm2d(depth)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.bottleneck = nn.Conv2d( depth*5, 256, kernel_size=3, padding=1 )  #512 1x1Conv\r\n        self.bn = nn.BatchNorm2d(256)\r\n        self.prelu = nn.PReLU()\r\n        #for m in self.conv2d_list:\r\n        #    m.weight.data.normal_(0, 0.01)\r\n        for m in self.modules():\r\n            if isinstance(m, nn.Conv2d):\r\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\r\n                m.weight.data.normal_(0, 0.01)\r\n            elif isinstance(m, nn.BatchNorm2d):\r\n                m.weight.data.fill_(1)\r\n                m.bias.data.zero_()\r\n            \r\n    def _make_stage_(self, dilation1, padding1):\r\n        Conv = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=padding1, dilation=dilation1, bias=True)#classes\r\n        Bn = nn.BatchNorm2d(256)\r\n        Relu = nn.ReLU(inplace=True)\r\n        return nn.Sequential(Conv, Bn, Relu)\r\n        \r\n\r\n    def forward(self, x):\r\n        #out = self.conv2d_list[0](x)\r\n        #mulBranches = [conv2d_l(x) for conv2d_l in self.conv2d_list]\r\n        size=x.shape[2:]\r\n        image_features=self.mean(x)\r\n        image_features=self.conv(image_features)\r\n        image_features = self.bn_x(image_features)\r\n        image_features = self.relu(image_features)\r\n        image_features=F.upsample(image_features, size=size, mode='bilinear', align_corners=True)\r\n        out_0 = self.conv2d_0(x)\r\n        out_0 = self.bn_0(out_0) \r\n        out_0 = self.relu(out_0)\r\n        out_1 = self.conv2d_1(x)\r\n        out_1 = self.bn_1(out_1) \r\n        out_1 = self.relu(out_1)\r\n        out_2 = self.conv2d_2(x)\r\n        out_2 = self.bn_2(out_2) \r\n        out_2 = self.relu(out_2)\r\n        out_3 = self.conv2d_3(x)\r\n        out_3 = self.bn_3(out_3) \r\n        out_3 = self.relu(out_3)\r\n        out = torch.cat([image_features, out_0, out_1, out_2, out_3], 1)\r\n        out = self.bottleneck(out)\r\n        out = self.bn(out)\r\n        out = self.prelu(out)\r\n        #for i in range(len(self.conv2d_list) - 1):\r\n        #    out += self.conv2d_list[i + 1](x)\r\n        \r\n        return out\r\n  \r\n\r\n\r\nclass ResNet(nn.Module):\r\n    def __init__(self, block, layers, num_classes):\r\n        self.inplanes = 64\r\n        super(ResNet, self).__init__()\r\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\r\n        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change\r\n        self.layer1 = self._make_layer(block, 64, layers[0])\r\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\r\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)\r\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)\r\n        self.layer5 = self._make_pred_layer(ASPP, [ 6, 12, 18], [6, 12, 18], 512)\r\n        self.main_classifier = nn.Conv2d(256, num_classes, kernel_size=1)\r\n        self.softmax = nn.Sigmoid()#nn.Softmax()\r\n        \r\n        for m in self.modules():\r\n            if isinstance(m, 
nn.Conv2d):\r\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\r\n                m.weight.data.normal_(0, 0.01)\r\n            elif isinstance(m, nn.BatchNorm2d):\r\n                m.weight.data.fill_(1)\r\n                m.bias.data.zero_()\r\n\r\n    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):\r\n        downsample = None\r\n        if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:\r\n            downsample = nn.Sequential(\r\n                nn.Conv2d(self.inplanes, planes * block.expansion,\r\n                          kernel_size=1, stride=stride, bias=False),\r\n                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))\r\n        for i in downsample._modules['1'].parameters():\r\n            i.requires_grad = False\r\n        layers = []\r\n        layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))\r\n        self.inplanes = planes * block.expansion\r\n        for i in range(1, blocks):\r\n            layers.append(block(self.inplanes, planes, dilation=dilation))\r\n\r\n        return nn.Sequential(*layers)\r\n\r\n    def _make_pred_layer(self, block, dilation_series, padding_series, num_classes):\r\n        return block(dilation_series, padding_series, num_classes)\r\n\r\n    def forward(self, x):\r\n        input_size = x.size()[2:]\r\n        x = self.conv1(x)\r\n        x = self.bn1(x)\r\n        x = self.relu(x)\r\n        x = self.maxpool(x)\r\n        x = self.layer1(x)\r\n        x = self.layer2(x)\r\n        x = self.layer3(x)\r\n        x = self.layer4(x)\r\n        fea = self.layer5(x)\r\n        x = self.main_classifier(fea)\r\n        #print(\"before upsample, tensor size:\", x.size())\r\n        x = F.upsample(x, input_size, mode='bilinear')  #upsample to the size of input image, scale=8\r\n        #print(\"after upsample, tensor size:\", x.size())\r\n        x = self.softmax(x)\r\n        return fea, x\r\n\r\nclass CoattentionModel(nn.Module):\r\n    def  __init__(self, block, layers, num_classes, all_channel=256, all_dim=60*60):\t#473./8=60\t\r\n        super(CoattentionModel, self).__init__()\r\n        self.encoder = ResNet(block, layers, num_classes)\r\n        self.linear_e = nn.Linear(all_channel, all_channel,bias = False)\r\n        self.channel = all_channel\r\n        self.dim = all_dim\r\n        self.gate = nn.Conv2d(all_channel, 1, kernel_size  = 1, bias = False)\r\n        self.gate_s = nn.Sigmoid()\r\n        self.conv1 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False)\r\n        self.conv2 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False)\r\n        self.bn1 = nn.BatchNorm2d(all_channel)\r\n        self.bn2 = nn.BatchNorm2d(all_channel)\r\n        self.prelu = nn.ReLU(inplace=True)\r\n        self.main_classifier1 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True)\r\n        self.main_classifier2 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True)\r\n        self.softmax = nn.Sigmoid()\r\n        \r\n        for m in self.modules():\r\n            if isinstance(m, nn.Conv2d):\r\n                #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\r\n                m.weight.data.normal_(0, 0.01)\r\n                #init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\r\n                #init.xavier_normal(m.weight.data)\r\n                #m.bias.data.fill_(0)\r\n            elif 
isinstance(m, nn.BatchNorm2d):\r\n                m.weight.data.fill_(1)\r\n                m.bias.data.zero_()    \r\n    \r\n\r\n    def forward(self, input1, input2): # note: input2 may contain multiple frames\r\n        \r\n        #input1_att, input2_att = self.coattention(input1, input2) \r\n        input_size = input1.size()[2:]\r\n        exemplar, temp = self.encoder(input1)\r\n        query, temp = self.encoder(input2)\r\n        fea_size = query.size()[2:]\r\n        all_dim = fea_size[0]*fea_size[1]\r\n        exemplar_flat = exemplar.view(-1, query.size()[1], all_dim) #N,C,H*W\r\n        query_flat = query.view(-1, query.size()[1], all_dim)\r\n        exemplar_t = torch.transpose(exemplar_flat,1,2).contiguous()  # N x (H*W) x C\r\n        exemplar_corr = self.linear_e(exemplar_t)  # multiply by the learned weight matrix of the affinity\r\n        A = torch.bmm(exemplar_corr, query_flat)  # affinity between every exemplar and query position\r\n        A1 = F.softmax(A.clone(), dim = 1)  # normalize over exemplar positions\r\n        B = F.softmax(torch.transpose(A,1,2),dim=1)  # transpose, then normalize over query positions\r\n        query_att = torch.bmm(exemplar_flat, A1).contiguous() # note: one could also add cross-interaction or a residual structure here\r\n        exemplar_att = torch.bmm(query_flat, B).contiguous()\r\n        \r\n        input1_att = exemplar_att.view(-1, query.size()[1], fea_size[0], fea_size[1])  \r\n        input2_att = query_att.view(-1, query.size()[1], fea_size[0], fea_size[1])\r\n        input1_mask = self.gate(input1_att)\r\n        input2_mask = self.gate(input2_att)\r\n        input1_mask = self.gate_s(input1_mask)\r\n        input2_mask = self.gate_s(input2_mask)\r\n        input1_att = input1_att * input1_mask\r\n        input2_att = input2_att * input2_mask\r\n        input1_att = torch.cat([input1_att, exemplar],1) \r\n        input2_att = torch.cat([input2_att, query],1)\r\n        input1_att  = self.conv1(input1_att )\r\n        input2_att  = self.conv2(input2_att ) \r\n        input1_att  = self.bn1(input1_att )\r\n        input2_att  = self.bn2(input2_att )\r\n        input1_att  = self.prelu(input1_att )\r\n        input2_att  = self.prelu(input2_att )\r\n        x1 = self.main_classifier1(input1_att)\r\n        x2 = self.main_classifier2(input2_att)   \r\n        x1 = F.upsample(x1, input_size, mode='bilinear')  #upsample to the size of input image, scale=8\r\n        x2 = F.upsample(x2, input_size, mode='bilinear')  #upsample to the size of input image, scale=8\r\n        #print(\"after upsample, tensor size:\", x.size())\r\n        x1 = self.softmax(x1)\r\n        x2 = self.softmax(x2)\r\n        \r\n        return x1, x2, temp  # x1, x2: N x 1 x H x W masks; temp: the plain DeepLab prediction for input2\r\n    \r\n\r\ndef Res_Deeplab(num_classes=2):\r\n    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes-1)\r\n    return model\r\n\r\ndef CoattentionNet(num_classes=2):\r\n    model = CoattentionModel(Bottleneck,[3, 4, 23, 3], num_classes-1)\r\n    return model\r\n
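\r\nif __name__ == '__main__':\r\n    # Minimal smoke test (illustrative only, not part of the original pipeline):\r\n    # a 473x473 input gives a 60x60 feature map (473./8 = 60, as noted above).\r\n    net = CoattentionNet(num_classes=2)\r\n    net.eval()\r\n    with torch.no_grad():\r\n        exemplar = torch.randn(1, 3, 473, 473)\r\n        query = torch.randn(1, 3, 473, 473)\r\n        mask1, mask2, temp = net(exemplar, query)\r\n    print(mask1.size(), mask2.size())  # expected: torch.Size([1, 1, 473, 473]) for both\r\n"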
  },
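  {
    "path": "examples/coattention_toy_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of the original release): the vanilla\nco-attention operation from deeplab/siamese_model_conf.py, run on toy tensors.\nSizes are arbitrary; only the tensor algebra mirrors CoattentionModel.forward.\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nB, C, H, W = 2, 256, 60, 60             # 473./8 = 60, as in the model\nlinear_e = nn.Linear(C, C, bias=False)  # learnable weight matrix of the co-attention\ngate = nn.Conv2d(C, 1, kernel_size=1, bias=False)\n\nexemplar = torch.randn(B, C, H, W)      # encoder features of frame a\nquery = torch.randn(B, C, H, W)         # encoder features of frame b\n\nexemplar_flat = exemplar.view(B, C, H * W)               # B x C x HW\nquery_flat = query.view(B, C, H * W)\nexemplar_t = exemplar_flat.transpose(1, 2).contiguous()  # B x HW x C\nS = torch.bmm(linear_e(exemplar_t), query_flat)          # affinity matrix, B x HW x HW\nA = F.softmax(S, dim=1)                  # normalize over exemplar positions\nBm = F.softmax(S.transpose(1, 2), dim=1) # normalize over query positions\n\nquery_att = torch.bmm(exemplar_flat, A)   # exemplar features summarized per query position\nexemplar_att = torch.bmm(query_flat, Bm)  # query features summarized per exemplar position\n\ninput1_att = exemplar_att.view(B, C, H, W)\nmask = torch.sigmoid(gate(input1_att))    # self-gate in [0, 1]\ninput1_att = input1_att * mask\nprint(input1_att.shape)                   # torch.Size([2, 256, 60, 60])\n"
  },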
  {
    "path": "deeplab/siamese_model_conf_try_single.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Sun Sep 16 10:01:14 2018\r\n\r\n@author: carri\r\n\"\"\"\r\n\r\nimport torch.nn as nn\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom torch.nn import init\r\naffine_par = True\r\nimport numpy as np\r\n#区别于siamese_model_concat的地方就是采用的最标准的deeplab_v3的基础网络，然后加上了非对称的分支\r\n\r\ndef conv3x3(in_planes, out_planes, stride=1):\r\n    \"\"\"3x3 convolution with padding\"\"\"\r\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\r\n                     padding=1, bias=False)\r\n\r\n\r\nclass BasicBlock(nn.Module):\r\n    expansion = 1\r\n\r\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\r\n        super(BasicBlock, self).__init__()\r\n        self.conv1 = conv3x3(inplanes, planes, stride)\r\n        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.conv2 = conv3x3(planes, planes)\r\n        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)\r\n        self.downsample = downsample\r\n        self.stride = stride\r\n\r\n    def forward(self, x):\r\n        residual = x\r\n\r\n        out = self.conv1(x)\r\n        out = self.bn1(out)\r\n        out = self.relu(out)\r\n\r\n        out = self.conv2(out)\r\n        out = self.bn2(out)\r\n\r\n        if self.downsample is not None:\r\n            residual = self.downsample(x)\r\n\r\n        out += residual\r\n        out = self.relu(out)\r\n\r\n        return out\r\n\r\n\r\nclass Bottleneck(nn.Module):\r\n    expansion = 4\r\n\r\n    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):\r\n        super(Bottleneck, self).__init__()\r\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)  # change\r\n        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)\r\n        padding = dilation\r\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,  # change\r\n                               padding=padding, bias=False, dilation=dilation)\r\n        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)\r\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\r\n        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.downsample = downsample\r\n        self.stride = stride\r\n\r\n    def forward(self, x):\r\n        residual = x\r\n\r\n        out = self.conv1(x)\r\n        out = self.bn1(out)\r\n        out = self.relu(out)\r\n\r\n        out = self.conv2(out)\r\n        out = self.bn2(out)\r\n        out = self.relu(out)\r\n\r\n        out = self.conv3(out)\r\n        out = self.bn3(out)\r\n\r\n        if self.downsample is not None:\r\n            residual = self.downsample(x)\r\n\r\n        out += residual\r\n        out = self.relu(out)\r\n\r\n        return out\r\n\r\n\r\nclass ASPP(nn.Module):\r\n    def __init__(self, dilation_series, padding_series, depth):\r\n        super(ASPP, self).__init__()\r\n        self.mean = nn.AdaptiveAvgPool2d((1,1))\r\n        self.conv= nn.Conv2d(2048, depth, 1,1)\r\n        self.bn_x = nn.BatchNorm2d(depth)\r\n        self.conv2d_0 = nn.Conv2d(2048, depth, kernel_size=1, stride=1)\r\n        self.bn_0 = nn.BatchNorm2d(depth)\r\n        self.conv2d_1 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[0], dilation=dilation_series[0])\r\n        self.bn_1 = nn.BatchNorm2d(depth)\r\n        self.conv2d_2 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, 
padding=padding_series[1], dilation=dilation_series[1])\r\n        self.bn_2 = nn.BatchNorm2d(depth)\r\n        self.conv2d_3 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[2], dilation=dilation_series[2])\r\n        self.bn_3 = nn.BatchNorm2d(depth)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.bottleneck = nn.Conv2d( depth*5, 256, kernel_size=3, padding=1 )  #512 1x1Conv\r\n        self.bn = nn.BatchNorm2d(256)\r\n        self.prelu = nn.PReLU()\r\n        #for m in self.conv2d_list:\r\n        #    m.weight.data.normal_(0, 0.01)\r\n        for m in self.modules():\r\n            if isinstance(m, nn.Conv2d):\r\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\r\n                m.weight.data.normal_(0, 0.01)\r\n            elif isinstance(m, nn.BatchNorm2d):\r\n                m.weight.data.fill_(1)\r\n                m.bias.data.zero_()\r\n            \r\n    def _make_stage_(self, dilation1, padding1):\r\n        Conv = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=padding1, dilation=dilation1, bias=True)#classes\r\n        Bn = nn.BatchNorm2d(256)\r\n        Relu = nn.ReLU(inplace=True)\r\n        return nn.Sequential(Conv, Bn, Relu)\r\n        \r\n\r\n    def forward(self, x):\r\n        #out = self.conv2d_list[0](x)\r\n        #mulBranches = [conv2d_l(x) for conv2d_l in self.conv2d_list]\r\n        size=x.shape[2:]\r\n        image_features=self.mean(x)\r\n        image_features=self.conv(image_features)\r\n        image_features = self.bn_x(image_features)\r\n        image_features = self.relu(image_features)\r\n        image_features=F.upsample(image_features, size=size, mode='bilinear', align_corners=True)\r\n        out_0 = self.conv2d_0(x)\r\n        out_0 = self.bn_0(out_0) \r\n        out_0 = self.relu(out_0)\r\n        out_1 = self.conv2d_1(x)\r\n        out_1 = self.bn_1(out_1) \r\n        out_1 = self.relu(out_1)\r\n        out_2 = self.conv2d_2(x)\r\n        out_2 = self.bn_2(out_2) \r\n        out_2 = self.relu(out_2)\r\n        out_3 = self.conv2d_3(x)\r\n        out_3 = self.bn_3(out_3) \r\n        out_3 = self.relu(out_3)\r\n        out = torch.cat([image_features, out_0, out_1, out_2, out_3], 1)\r\n        out = self.bottleneck(out)\r\n        out = self.bn(out)\r\n        out = self.prelu(out)\r\n        #for i in range(len(self.conv2d_list) - 1):\r\n        #    out += self.conv2d_list[i + 1](x)\r\n        \r\n        return out\r\n  \r\n\r\n\r\nclass ResNet(nn.Module):\r\n    def __init__(self, block, layers, num_classes):\r\n        self.inplanes = 64\r\n        super(ResNet, self).__init__()\r\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\r\n        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change\r\n        self.layer1 = self._make_layer(block, 64, layers[0])\r\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\r\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)\r\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)\r\n        self.layer5 = self._make_pred_layer(ASPP, [ 6, 12, 18], [6, 12, 18], 512)\r\n        self.main_classifier = nn.Conv2d(256, num_classes, kernel_size=1)\r\n        self.softmax = nn.Sigmoid()#nn.Softmax()\r\n        \r\n        for m in self.modules():\r\n            if isinstance(m, 
nn.Conv2d):\r\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\r\n                m.weight.data.normal_(0, 0.01)\r\n            elif isinstance(m, nn.BatchNorm2d):\r\n                m.weight.data.fill_(1)\r\n                m.bias.data.zero_()\r\n\r\n    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):\r\n        downsample = None\r\n        if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:\r\n            downsample = nn.Sequential(\r\n                nn.Conv2d(self.inplanes, planes * block.expansion,\r\n                          kernel_size=1, stride=stride, bias=False),\r\n                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))\r\n        for i in downsample._modules['1'].parameters():\r\n            i.requires_grad = False\r\n        layers = []\r\n        layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))\r\n        self.inplanes = planes * block.expansion\r\n        for i in range(1, blocks):\r\n            layers.append(block(self.inplanes, planes, dilation=dilation))\r\n\r\n        return nn.Sequential(*layers)\r\n\r\n    def _make_pred_layer(self, block, dilation_series, padding_series, num_classes):\r\n        return block(dilation_series, padding_series, num_classes)\r\n\r\n    def forward(self, x):\r\n        input_size = x.size()[2:]\r\n        x = self.conv1(x)\r\n        x = self.bn1(x)\r\n        x = self.relu(x)\r\n        x = self.maxpool(x)\r\n        x = self.layer1(x)\r\n        x = self.layer2(x)\r\n        x = self.layer3(x)\r\n        x = self.layer4(x)\r\n        fea = self.layer5(x)\r\n        x = self.main_classifier(fea)\r\n        #print(\"before upsample, tensor size:\", x.size())\r\n        x = F.upsample(x, input_size, mode='bilinear')  #upsample to the size of input image, scale=8\r\n        #print(\"after upsample, tensor size:\", x.size())\r\n        x = self.softmax(x)\r\n        return fea, x\r\n\r\nclass CoattentionModel(nn.Module):\r\n    def  __init__(self, block, layers, num_classes,  all_channel=256, all_dim=60*60):\t#473./8=60\r\n        super(CoattentionModel, self).__init__()\r\n        self.nframes = 2\r\n        self.encoder = ResNet(block, layers, num_classes)\r\n        self.linear_e = nn.Linear(all_channel, all_channel,bias = False)\r\n        self.channel = all_channel\r\n        self.dim = all_dim\r\n        self.gate = nn.Conv2d(all_channel, 1, kernel_size  = 1, bias = False)\r\n        self.gate_s = nn.Sigmoid()\r\n        self.conv1 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False)\r\n        self.conv2 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False)\r\n        self.bn1 = nn.BatchNorm2d(all_channel, affine=affine_par)\r\n        self.bn2 = nn.BatchNorm2d(all_channel, affine=affine_par)\r\n        self.prelu = nn.ReLU(inplace=True)\r\n        self.main_classifier1 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True)\r\n        self.main_classifier2 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True)\r\n        self.softmax = nn.Sigmoid()\r\n\r\n        for m in self.modules():\r\n            if isinstance(m, nn.Conv2d):\r\n                #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\r\n                m.weight.data.normal_(0, 0.01)\r\n                #init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\r\n                #init.xavier_normal(m.weight.data)\r\n             
   #m.bias.data.fill_(0)\r\n            elif isinstance(m, nn.BatchNorm2d):\r\n                m.weight.data.fill_(1)\r\n                m.bias.data.zero_()    \r\n    \r\n\t\t\r\n    def forward(self, input1, input2): #注意input2 可以是多帧图像\r\n        \r\n        #input1_att, input2_att = self.coattention(input1, input2)\r\n        exemplar, temp = self.encoder(input1)\r\n\r\n        #print('feature size:', input1.size())\r\n        if len(input2.size() )>4:\r\n            B, N, C, H, W = input2.size()  # 2,2,3,473,473\r\n            input_size = [H, W]\r\n            video_frames2 = [elem.view(B, C, H, W) for elem in input2.split(1, dim=1)]\r\n            # the length of exemplars is equal to the nframes\r\n            querys= [self.encoder(video_frames2[frame]) for frame in range(0,self.nframes)]\r\n            #query = torch.cat([querys[][0]],dim=1)\r\n            #query1 = torch.cat([querys[1]], dim=1)\r\n            query = torch.cat(([querys[frame][0] for frame in range(0,self.nframes )]), dim=2)\r\n            #print('query size:', query.size()) 2*512*49*49\r\n            predict_mask = torch.cat(([querys[frame][1] for frame in range(0,self.nframes )]), dim=1)\r\n            #print('feature size:', exemplar.size())\r\n            fea_size = exemplar.size()[2:]\r\n            exemplar_flat = exemplar.view(-1, self.channel, fea_size[0]*fea_size[1]) #N,C,H*W\r\n            exemplar_t = torch.transpose(exemplar_flat, 1, 2).contiguous()  # batch size x dim x num\r\n            exemplar_corr = self.linear_e(exemplar_t)\r\n            #coattention_fea = 0\r\n            query_flat = query.view(-1, self.channel, self.nframes*fea_size[0]*fea_size[1])\r\n            A = torch.bmm(exemplar_corr, query_flat)\r\n            A = F.softmax(A, dim = 1) #\r\n            B = F.softmax(torch.transpose(A,1,2),dim=1)\r\n            query_att = torch.bmm(exemplar_flat, A).contiguous() #注意我们这个地方要不要用交互以及Residual的结构\r\n            exemplar_att = torch.bmm(query_flat, B).contiguous()\r\n\r\n            input1_att = exemplar_att.view(-1, self.channel, fea_size[0], fea_size[1])\r\n            input2_att = query_att.view(-1, self.channel, self.nframes*fea_size[0], fea_size[1])\r\n            input1_mask = self.gate(input1_att)\r\n            #input2_mask = self.gate(input2_att)\r\n            input1_mask = self.gate_s(input1_mask)\r\n            #input2_mask = self.gate_s(input2_mask)\r\n            input1_att_org = input1_att * input1_mask\r\n            #coattention_fea = coattention_fea + input1_att_org\r\n\r\n            #print('h_v size, h_v_org size:', torch.max(input1_att), torch.max(exemplar), torch.min(input1_att), torch.max(exemplar))\r\n            input1_att = torch.cat([input1_att_org, exemplar],1)\r\n            input1_att  = self.conv1(input1_att )\r\n            input1_att  = self.bn1(input1_att )\r\n            input1_att  = self.prelu(input1_att )\r\n            x1 = self.main_classifier1(input1_att)\r\n            x1 = F.upsample(x1, input_size, mode='bilinear')  #upsample to the size of input image, scale=8\r\n              #upsample to the size of input image, scale=8\r\n            #print(\"after upsample, tensor size:\", x.size())\r\n            x1  = self.softmax(x1)\r\n        else:\r\n            x1 = exemplar\r\n\r\n        return x1, temp  #shape: NxCx\r\n\r\n\r\ndef CoattentionNet(num_classes=2,nframes=2):\r\n    model = CoattentionModel(Bottleneck,[3, 4, 23, 3], num_classes-1)\r\n\t\r\n    return model"
  },
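  {
    "path": "examples/group_coattention_toy_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of the original release): the grouped\nco-attention of deeplab/siamese_model_conf_try_single.py, where one exemplar\nframe attends to nframes query frames whose features are concatenated along\nthe spatial dimension. W_mat stands in for the learned linear_e weights.\"\"\"\nimport torch\nimport torch.nn.functional as F\n\nB, C, H, W, nframes = 1, 256, 60, 60, 2\nexemplar = torch.randn(B, C, H, W)\nframes = [torch.randn(B, C, H, W) for _ in range(nframes)]\n\n# stack the per-frame features spatially, as CoattentionModel.forward does (dim=2)\nquery = torch.cat(frames, dim=2)                # B x C x (nframes*H) x W\nexemplar_flat = exemplar.view(B, C, H * W)      # B x C x HW\nquery_flat = query.view(B, C, nframes * H * W)  # B x C x (nframes*HW)\n\nW_mat = torch.randn(C, C)\nS = torch.bmm(exemplar_flat.transpose(1, 2).matmul(W_mat), query_flat)  # B x HW x (nframes*HW)\nBm = F.softmax(S.transpose(1, 2), dim=1)        # normalize over all query positions\nexemplar_att = torch.bmm(query_flat, Bm)        # B x C x HW: one summary over all frames\nprint(exemplar_att.view(B, C, H, W).shape)      # torch.Size([1, 256, 60, 60])\n"
  },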
  {
    "path": "deeplab/utils.py",
    "content": "import torch\n#from tensorboard_logger import log_value\nfrom torch.autograd import Variable\n\n\ndef loss_calc(pred, label, ignore_label):\n    \"\"\"\n    This function returns cross entropy loss for semantic segmentation\n    \"\"\"\n    # out shape batch_size x channels x h x w -> batch_size x channels x h x w\n    # label shape h x w x 1 x batch_size  -> batch_size x 1 x h x w\n    label = Variable(label.long()).cuda()\n    criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label).cuda()\n\n    return criterion(pred, label)\n\n\ndef lr_poly(base_lr, iter, max_iter, power):\n    return base_lr * ((1 - float(iter) / max_iter) ** power)\n\n\ndef get_1x_lr_params(model):\n    \"\"\"\n    This generator returns all the parameters of the net except for\n    the last classification layer. Note that for each batchnorm layer,\n    requires_grad is set to False in deeplab_resnet.py, therefore this function does not return\n    any batchnorm parameter\n    \"\"\"\n    b = [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3, model.layer4]\n\n    for i in range(len(b)):\n        for j in b[i].modules():\n            jj = 0\n            for k in j.parameters():\n                jj += 1\n                if k.requires_grad:\n                    yield k\n\n\ndef get_10x_lr_params(model):\n    \"\"\"\n    This generator returns all the parameters for the last layer of the net,\n    which does the classification of pixel into classes\n    \"\"\"\n    b = [model.layer5.parameters(), model.main_classifier.parameters()]\n\n    for j in range(len(b)):\n        for i in b[j]:\n            yield i\n\n\ndef adjust_learning_rate(optimizer, i_iter, learning_rate, num_steps, power):\n    \"\"\"Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs\"\"\"\n    lr = lr_poly(learning_rate, i_iter, num_steps, power)\n    #log_value('learning', lr, i_iter)\n    optimizer.param_groups[0]['lr'] = lr\n    optimizer.param_groups[1]['lr'] = lr * 10\n"
  },
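  {
    "path": "examples/lr_poly_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of the original release): the polynomial\nlearning-rate decay implemented by lr_poly in deeplab/utils.py, evaluated at\na few iterations with the training defaults (base lr 2.5e-4, power 0.9,\n~30k max iterations, as noted in train_iteration_conf.py).\"\"\"\n\ndef lr_poly(base_lr, i, max_iter, power):\n    return base_lr * ((1 - float(i) / max_iter) ** power)\n\nfor i in [0, 10000, 20000, 29999]:\n    print(i, lr_poly(2.5e-4, i, 30000, 0.9))\n# approx: 2.5e-4, 1.7e-4, 9.3e-5, 2.4e-8\n"
  },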
  {
    "path": "densecrf_apply_cvpr2019.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Fri Mar  1 20:37:37 2019\n\n@author: xiankai\n\"\"\"\n\nimport pydensecrf.densecrf as dcrf\nimport numpy as np\nimport sys\nimport os\n\n\nfrom skimage.io import imread, imsave\nfrom pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian, unary_from_softmax\nfrom os import listdir, makedirs\nfrom os.path import isfile, join\nfrom multiprocessing import Process\n\n            \ndef worker(scale, g_dim, g_factor,s_dim,C_dim,c_factor):\n    davis_path = '/home/xiankai/work/DAVIS-2016/JPEGImages/480p'#'/home/ying/tracking/pdb_results/FBMS-results'\n    origin_path = '/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/COS-78.2'#'/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/ECCV'#'/media/xiankai/Data/segmentation/match-Weaksup_VideoSeg/result/test/davis_iteration_conf_sal_match_scale/COS/'\n    out_folder = '/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/cvpr2019_crfs'#'/media/xiankai/Data/ECCV-crf'#'/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/davis_ICCV_new/'\n    if not os.path.exists(out_folder):\n        os.makedirs(out_folder)\n    origin_file = listdir(origin_path)\n    origin_file.sort()\n    for i in range(0, len(origin_file)):\n        d = origin_file[i]\n        vidDir = join(davis_path, d)\n        out_folder1 = join(out_folder,'f'+str(scale)+str(g_dim)+str(g_factor)+'_'+'s'+str(s_dim)+'_'+'c'+str(C_dim)+str(c_factor))\n        resDir = join(out_folder1, d)\n        if not os.path.exists(resDir):\n                os.makedirs(resDir)\n        rgb_file = listdir(vidDir)\n        rgb_file.sort()\n        for ii in range(0,len(rgb_file)):  \n            f = rgb_file[ii]\n            img = imread(join(vidDir, f))\n            segDir = join(origin_path, d)\n            frameName = str.split(f, '.')[0]\n            anno_rgb = imread(segDir + '/' + frameName + '.png').astype(np.uint32)\n            min_val = np.min(anno_rgb.ravel())\n            max_val = np.max(anno_rgb.ravel())\n            out = (anno_rgb.astype('float') - min_val) / (max_val - min_val)\n            labels = np.zeros((2, img.shape[0], img.shape[1]))\n            labels[1, :, :] = out\n            labels[0, :, :] = 1 - out\n    \n            colors = [0, 255]\n            colorize = np.empty((len(colors), 1), np.uint8)\n            colorize[:,0] = colors\n            n_labels = 2\n    \n            crf = dcrf.DenseCRF(img.shape[1] * img.shape[0], n_labels)\n    \n            U = unary_from_softmax(labels,scale)\n            crf.setUnaryEnergy(U)\n    \n            feats = create_pairwise_gaussian(sdims=(g_dim, g_dim), shape=img.shape[:2])\n    \n            crf.addPairwiseEnergy(feats, compat=g_factor,\n                            kernel=dcrf.DIAG_KERNEL,\n                            normalization=dcrf.NORMALIZE_SYMMETRIC)\n    \n            feats = create_pairwise_bilateral(sdims=(s_dim,s_dim), schan=(C_dim, C_dim, C_dim),# 30,5\n                                          img=img, chdim=2)\n    \n            crf.addPairwiseEnergy(feats, compat=c_factor,\n                            kernel=dcrf.DIAG_KERNEL,\n                            normalization=dcrf.NORMALIZE_SYMMETRIC)\n    \n            #Q = crf.inference(5)\n            Q, tmp1, tmp2 = crf.startInference()\n            for i in range(5):\n                #print(\"KL-divergence at {}: {}\".format(i, crf.klDivergence(Q)))\n                crf.stepInference(Q, tmp1, tmp2)\n    \n            MAP = 
np.argmax(Q, axis=0)\n            MAP = colorize[MAP]\n            \n            imsave(resDir + '/' + frameName + '.png', MAP.reshape(anno_rgb.shape))\n            print (\"Saving: \" + resDir + '/' + frameName + '.png')\n\n# grid-search values for the CRF hyper-parameters (commented lists are earlier sweeps)\nscales = [1]#[0.5,1]#[0.1,0.3,0.5,0.6]#[0.5, 1.0]\ng_dims = [1]#[1,3]#[1,3]\ng_factors =[5]#[3,5,10] #[ 3, 5,10]\ns_dims = [10,15,20] #[5,10,20]#[11, 12, 13]#[9,10,11] #10\nCs = [7]#[5]#[8]# [ 7,8,9,10] #8\nb_factors = [8,9,10]\nfor scale in scales: \n    for g_dim in g_dims:\n        for ii in range(0,len(g_factors)):\n            g_factor = g_factors[ii]\n            for jj in range(0,len(s_dims)):\n                s_dim = s_dims[jj]\n                for cs in Cs:\n                    p1 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, b_factors[0]))\n                    p2 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, b_factors[1]))\n                    p3 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, b_factors[2]))\n                    #p4 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, 4))\n                    #p5 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, 1))\n                    #p6 = Process(target = worker, args = (scale, g_dim, g_factor,s_dim,cs, 1))\n                    \n                    p1.start()\n                    p2.start()\n                    p3.start()\n                    #p4.start()\n                    #p5.start()\n                    #p6.start()\n                    # wait for this batch of workers before launching the next\n                    # combination, so the sweep does not spawn unbounded processes\n                    p1.join()\n                    p2.join()\n                    p3.join()\n"
  },
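  {
    "path": "examples/densecrf_single_image.py",
    "content": "\"\"\"Illustrative sketch (not part of the original release): refining a single\nforeground-probability map with pydensecrf, using the same unary/pairwise terms\nas densecrf_apply_cvpr2019.py but the simpler DenseCRF2D interface. The file\npaths are placeholders; the pairwise parameters are one point of that script's\ngrid search (g_dim=1, g_factor=5, s_dim=15, C=7, b_factor=9).\"\"\"\nimport numpy as np\nimport pydensecrf.densecrf as dcrf\nfrom pydensecrf.utils import unary_from_softmax\nfrom skimage.io import imread\n\nimg = imread('frame.jpg')                             # H x W x 3 uint8, placeholder\nprob = imread('mask.png').astype(np.float32) / 255.0  # foreground probability\nprob = prob.clip(1e-5, 1.0 - 1e-5)                    # avoid log(0) in the unary\n\nprobs = np.stack([1.0 - prob, prob])                  # 2 x H x W: background, foreground\ncrf = dcrf.DenseCRF2D(img.shape[1], img.shape[0], 2)\ncrf.setUnaryEnergy(unary_from_softmax(probs))\ncrf.addPairwiseGaussian(sxy=1, compat=5)              # location-only smoothness term\ncrf.addPairwiseBilateral(sxy=15, srgb=7, rgbim=np.ascontiguousarray(img), compat=9)\nrefined = np.argmax(crf.inference(5), axis=0).reshape(img.shape[:2])\n"
  },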
  {
    "path": "pretrained/deep_labv3/readme.md",
    "content": "\n"
  },
  {
    "path": "test_coattention_conf.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Mon Sep 17 17:53:20 2018\r\n\r\n@author: carri\r\n\"\"\"\r\n\r\nimport argparse\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torch.utils import data\r\nimport numpy as np\r\nimport pickle\r\nimport cv2\r\nfrom torch.autograd import Variable\r\nimport torch.optim as optim\r\nimport scipy.misc\r\nimport torch.backends.cudnn as cudnn\r\nimport sys\r\nimport os\r\nimport os.path as osp\r\nfrom dataloaders import PairwiseImg_test as db\r\n#from dataloaders import StaticImg as db #采用voc dataset的数据设置格式方法\r\nimport matplotlib.pyplot as plt\r\nimport random\r\nimport timeit\r\nfrom PIL import Image\r\nfrom collections import OrderedDict\r\nimport matplotlib.pyplot as plt\r\nimport torch.nn as nn\r\n#from utils.colorize_mask import cityscapes_colorize_mask, VOCColorize\r\n#import pydensecrf.densecrf as dcrf\r\n#from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian\r\nfrom deeplab.siamese_model_conf import CoattentionNet\r\nfrom torchvision.utils import save_image\r\n\r\ndef get_arguments():\r\n    \"\"\"Parse all the arguments provided from the CLI.\r\n    \r\n    Returns:\r\n      A list of parsed arguments.\r\n    \"\"\"\r\n    parser = argparse.ArgumentParser(description=\"PSPnet\")\r\n    parser.add_argument(\"--dataset\", type=str, default='cityscapes',\r\n                        help=\"voc12, cityscapes, or pascal-context\")\r\n\r\n    # GPU configuration\r\n    parser.add_argument(\"--cuda\", default=True, help=\"Run on CPU or GPU\")\r\n    parser.add_argument(\"--gpus\", type=str, default=\"0\",\r\n                        help=\"choose gpu device.\")\r\n    parser.add_argument(\"--seq_name\", default = 'bmx-bumps')\r\n    parser.add_argument(\"--use_crf\", default = 'True')\r\n    parser.add_argument(\"--sample_range\", default =5)\r\n    \r\n    return parser.parse_args()\r\n\r\ndef configure_dataset_model(args):\r\n    if args.dataset == 'voc12':\r\n        args.data_dir ='/home/wty/AllDataSet/VOC2012'  #Path to the directory containing the PASCAL VOC dataset\r\n        args.data_list = './dataset/list/VOC2012/test.txt'  #Path to the file listing the images in the dataset\r\n        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) \r\n        #RBG mean, first subtract mean and then change to BGR\r\n        args.ignore_label = 255   #The index of the label to ignore during the training\r\n        args.num_classes = 21  #Number of classes to predict (including background)\r\n        args.restore_from = './snapshots/voc12/psp_voc12_14.pth'  #Where restore model parameters from\r\n        args.save_segimage = True\r\n        args.seg_save_dir = \"./result/test/VOC2012\"\r\n        args.corp_size =(505, 505)\r\n        \r\n    elif args.dataset == 'davis': \r\n        args.batch_size = 1# 1 card: 5, 2 cards: 10 Number of images sent to the network in one step, 16 on paper\r\n        args.maxEpoches = 15 # 1 card: 15, 2 cards: 15 epoches, equal to 30k iterations, max iterations= maxEpoches*len(train_aug)/batch_size_per_gpu'),\r\n        args.data_dir = 'your_path/DAVIS-2016'   # 37572 image pairs\r\n        args.data_list = 'your_path/DAVIS-2016/test_seqs.txt'  # Path to the file listing the images in the dataset\r\n        args.ignore_label = 255     #The index of the label to ignore during the training\r\n        args.input_size = '473,473' #Comma-separated string with height and width of images\r\n        args.num_classes = 2      #Number of 
classes to predict (including background)\r\n        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)       # saving model file and log record during the process of training\r\n        args.restore_from = './your_path.pth' #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #\r\n        args.snapshot_dir = './snapshots/davis_iteration/'          #Where to save snapshots of the model\r\n        args.save_segimage = True\r\n        args.seg_save_dir = \"./result/test/davis_iteration_conf\"\r\n        args.vis_save_dir = \"./result/test/davis_vis\"\r\n        args.corp_size =(473, 473)\r\n        \r\n    else:\r\n        print(\"dataset error\")\r\n\r\ndef convert_state_dict(state_dict):\r\n    \"\"\"Converts a state dict saved from a dataParallel module to normal \r\n       module state_dict inplace\r\n       :param state_dict is the loaded DataParallel model_state\r\n       You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it \r\n       without DataParallel. You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can \r\n       load the weights file, create a new ordered dict without the module prefix, and load it back \r\n    \"\"\"\r\n    state_dict_new = OrderedDict()\r\n    #print(type(state_dict))\r\n    for k, v in state_dict.items():\r\n        #print(k)\r\n        name = k[7:] # remove the prefix module.\r\n        # My heart is broken, the pytorch have no ability to do with the problem.\r\n        state_dict_new[name] = v\r\n        if name == 'linear_e.weight':\r\n            np.save('weight_matrix.npy',v.cpu().numpy())\r\n    return state_dict_new\r\n\r\ndef sigmoid(inX): \r\n    return 1.0/(1+np.exp(-inX))#定义一个sigmoid方法，其本质就是1/(1+e^-x)\r\n\r\ndef main():\r\n    args = get_arguments()\r\n    print(\"=====> Configure dataset and model\")\r\n    configure_dataset_model(args)\r\n    print(args)\r\n    model = CoattentionNet(num_classes=args.num_classes)\r\n    \r\n    saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage)\r\n    #print(saved_state_dict.keys())\r\n    #model.load_state_dict({k.replace('pspmodule.',''):v for k,v in torch.load(args.restore_from)['state_dict'].items()})\r\n    model.load_state_dict( convert_state_dict(saved_state_dict[\"model\"]) ) #convert_state_dict(saved_state_dict[\"model\"])\r\n\r\n    model.eval()\r\n    model.cuda()\r\n    if args.dataset == 'voc12':\r\n        testloader = data.DataLoader(VOCDataTestSet(args.data_dir, args.data_list, crop_size=(505, 505),mean= args.img_mean), \r\n                                    batch_size=1, shuffle=False, pin_memory=True)\r\n        interp = nn.Upsample(size=(505, 505), mode='bilinear')\r\n        voc_colorize = VOCColorize()\r\n        \r\n    elif args.dataset == 'davis':  #for davis 2016\r\n        db_test = db.PairwiseImg(train=False, inputRes=(473,473), db_root_dir=args.data_dir,  transform=None, seq_name = None, sample_range = args.sample_range) #db_root_dir() --> '/path/to/DAVIS-2016' train path\r\n        testloader = data.DataLoader(db_test, batch_size= 1, shuffle=False, num_workers=0)\r\n        #voc_colorize = VOCColorize()\r\n    else:\r\n        print(\"dataset error\")\r\n\r\n    data_list = []\r\n\r\n    if args.save_segimage:\r\n        if not os.path.exists(args.seg_save_dir) and not os.path.exists(args.vis_save_dir):\r\n            os.makedirs(args.seg_save_dir)\r\n            
os.makedirs(args.vis_save_dir)\r\n    print(\"======> test set size:\", len(testloader))\r\n    my_index = 0\r\n    old_temp=''\r\n    for index, batch in enumerate(testloader):\r\n        print('%d processd'%(index))\r\n        target = batch['target']\r\n        #search = batch['search']\r\n        temp = batch['seq_name']\r\n        args.seq_name=temp[0]\r\n        print(args.seq_name)\r\n        if old_temp==args.seq_name:\r\n            my_index = my_index+1\r\n        else:\r\n            my_index = 0\r\n        output_sum = 0   \r\n        for i in range(0,args.sample_range):  \r\n            search = batch['search'+'_'+str(i)]\r\n            search_im = search\r\n            #print(search_im.size())\r\n            output = model(Variable(target, volatile=True).cuda(),Variable(search_im, volatile=True).cuda())\r\n            #print(output[0]) # output有两个\r\n            output_sum = output_sum + output[0].data[0,0].cpu().numpy() #分割那个分支的结果\r\n            #np.save('infer'+str(i)+'.npy',output1)\r\n            #output2 = output[1].data[0, 0].cpu().numpy() #interp'\r\n        \r\n        output1 = output_sum/args.sample_range\r\n     \r\n        first_image = np.array(Image.open(args.data_dir+'/JPEGImages/480p/blackswan/00000.jpg'))\r\n        original_shape = first_image.shape \r\n        output1 = cv2.resize(output1, (original_shape[1],original_shape[0]))\r\n\r\n        mask = (output1*255).astype(np.uint8)\r\n        #print(mask.shape[0])\r\n        mask = Image.fromarray(mask)\r\n        \r\n\r\n        if args.dataset == 'voc12':\r\n            print(output.shape)\r\n            print(size)\r\n            output = output[:,:size[0],:size[1]]\r\n            output = output.transpose(1,2,0)\r\n            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)\r\n            if args.save_segimage:\r\n                seg_filename = os.path.join(args.seg_save_dir, '{}.png'.format(name[0]))\r\n                color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')\r\n                color_file.save(seg_filename)\r\n                \r\n        elif args.dataset == 'davis':\r\n            \r\n            save_dir_res = os.path.join(args.seg_save_dir, 'Results', args.seq_name)\r\n            old_temp=args.seq_name\r\n            if not os.path.exists(save_dir_res):\r\n                os.makedirs(save_dir_res)\r\n            if args.save_segimage:   \r\n                my_index1 = str(my_index).zfill(5)\r\n                seg_filename = os.path.join(save_dir_res, '{}.png'.format(my_index1))\r\n                #color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')\r\n                mask.save(seg_filename)\r\n                #np.concatenate((torch.zeros(1, 473, 473), mask, torch.zeros(1, 512, 512)),axis = 0)\r\n                #save_image(output1 * 0.8 + target.data, args.vis_save_dir, normalize=True)\r\n        else:\r\n            print(\"dataset error\")\r\n    \r\n\r\nif __name__ == '__main__':\r\n    main()\r\n"
  },
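  {
    "path": "examples/load_dataparallel_checkpoint.py",
    "content": "\"\"\"Illustrative sketch (not part of the original release): stripping the\n'module.' prefix that nn.DataParallel adds to parameter names, the same idea\nas convert_state_dict in the test scripts. The checkpoint path is a placeholder.\"\"\"\nfrom collections import OrderedDict\nimport torch\n\ncheckpoint = torch.load('co_attention.pth', map_location='cpu')  # placeholder path\nstate = OrderedDict((k[len('module.'):] if k.startswith('module.') else k, v)\n                    for k, v in checkpoint['model'].items())\n# model.load_state_dict(state)  # then run inference under torch.no_grad()\n"
  },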
  {
    "path": "test_iteration_conf_group.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Mon Sep 17 17:53:20 2018\r\n\r\n@author: carri\r\n\"\"\"\r\n\r\nimport argparse\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torch.utils import data\r\nimport numpy as np\r\nimport pickle\r\nimport cv2\r\nfrom torch.autograd import Variable\r\nimport torch.optim as optim\r\nimport scipy.misc\r\nimport torch.backends.cudnn as cudnn\r\nimport sys\r\nimport os\r\nimport os.path as osp\r\nfrom dataloaders import PairwiseImg_video_test_try as db\r\n#from dataloaders import StaticImg as db #采用voc dataset的数据设置格式方法\r\nimport matplotlib.pyplot as plt\r\nimport random\r\nimport timeit\r\nfrom PIL import Image\r\nfrom collections import OrderedDict\r\nimport matplotlib.pyplot as plt\r\nimport torch.nn as nn\r\nfrom utils.colorize_mask import cityscapes_colorize_mask, VOCColorize\r\n#import pydensecrf.densecrf as dcrf\r\n#from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian\r\nfrom deeplab.siamese_model_conf_try import CoattentionNet\r\nfrom torchvision.utils import save_image\r\n\r\ndef get_arguments():\r\n    \"\"\"Parse all the arguments provided from the CLI.\r\n    \r\n    Returns:\r\n      A list of parsed arguments.\r\n    \"\"\"\r\n    parser = argparse.ArgumentParser(description=\"PSPnet\")\r\n    parser.add_argument(\"--dataset\", type=str, default='cityscapes',\r\n                        help=\"voc12, cityscapes, or pascal-context\")\r\n\r\n    # GPU configuration\r\n    parser.add_argument(\"--cuda\", default=True, help=\"Run on CPU or GPU\")\r\n    parser.add_argument(\"--gpus\", type=str, default=\"0\",\r\n                        help=\"choose gpu device.\")\r\n    parser.add_argument(\"--seq_name\", default = 'bmx-bumps')\r\n    parser.add_argument(\"--use_crf\", default = 'True')\r\n    parser.add_argument(\"--sample_range\", default =3)\r\n    \r\n    return parser.parse_args()\r\n\r\ndef configure_dataset_model(args):\r\n\r\n    args.batch_size = 1# 1 card: 5, 2 cards: 10 Number of images sent to the network in one step, 16 on paper\r\n    args.maxEpoches = 15 # 1 card: 15, 2 cards: 15 epoches, equal to 30k iterations, max iterations= maxEpoches*len(train_aug)/batch_size_per_gpu'),\r\n    args.data_dir = '/home/xiankai/work/DAVIS-2016'   # 37572 image pairs\r\n    args.data_list = '/home/xiankai/work/DAVIS-2016/test_seqs.txt'  # Path to the file listing the images in the dataset\r\n    args.ignore_label = 255     #The index of the label to ignore during the training\r\n    args.input_size = '473,473' #Comma-separated string with height and width of images\r\n    args.num_classes = 2      #Number of classes to predict (including background)\r\n    args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)       # saving model file and log record during the process of training\r\n    args.restore_from = './co_attention_davis_43.pth' #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #\r\n    args.snapshot_dir = './snapshots/davis_iteration/'          #Where to save snapshots of the model\r\n    args.save_segimage = True\r\n    args.seg_save_dir = \"./result/test/davis_iteration_conf_try\"\r\n    args.vis_save_dir = \"./result/test/davis_vis\"\r\n    args.corp_size =(473, 473)\r\n\r\ndef convert_state_dict(state_dict):\r\n    \"\"\"Converts a state dict saved from a dataParallel module to normal \r\n       module state_dict inplace\r\n       :param state_dict is the loaded DataParallel model_state\r\n       You 
probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it \r\n       without DataParallel. You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can \r\n       load the weights file, create a new ordered dict without the module prefix, and load it back \r\n    \"\"\"\r\n    state_dict_new = OrderedDict()\r\n    #print(type(state_dict))\r\n    for k, v in state_dict.items():\r\n        #print(k)\r\n        name = k[7:] # remove the prefix module.\r\n        # My heart is broken, the pytorch have no ability to do with the problem.\r\n        state_dict_new[name] = v\r\n        if name == 'linear_e.weight':\r\n            np.save('weight_matrix.npy',v.cpu().numpy())\r\n    return state_dict_new\r\n\r\ndef sigmoid(inX): \r\n    return 1.0/(1+np.exp(-inX))#定义一个sigmoid方法，其本质就是1/(1+e^-x)\r\n\r\ndef main():\r\n    args = get_arguments()\r\n    print(\"=====> Configure dataset and model\")\r\n    configure_dataset_model(args)\r\n    print(args)\r\n\r\n    print(\"=====> Set GPU for training\")\r\n    if args.cuda:\r\n        print(\"====> Use gpu id: '{}'\".format(args.gpus))\r\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = args.gpus\r\n        if not torch.cuda.is_available():\r\n            raise Exception(\"No GPU found or Wrong gpu id, please run without --cuda\")\r\n    model = CoattentionNet(num_classes=args.num_classes, nframes = args.sample_range)\r\n    for param in model.parameters():\r\n        param.requires_grad = False\r\n\r\n    saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage)\r\n    #print(saved_state_dict.keys())\r\n    #model.load_state_dict({k.replace('pspmodule.',''):v for k,v in torch.load(args.restore_from)['state_dict'].items()})\r\n    model.load_state_dict( convert_state_dict(saved_state_dict[\"model\"]) ) #convert_state_dict(saved_state_dict[\"model\"])\r\n\r\n    model.eval()\r\n    model.cuda()\r\n\r\n    db_test = db.PairwiseImg(train=False, inputRes=(473,473), db_root_dir=args.data_dir,  transform=None, seq_name = None, sample_range = args.sample_range) #db_root_dir() --> '/path/to/DAVIS-2016' train path\r\n    testloader = data.DataLoader(db_test, batch_size= 1, shuffle=False, num_workers=0)\r\n    voc_colorize = VOCColorize()\r\n\r\n    data_list = []\r\n\r\n    if args.save_segimage:\r\n        if not os.path.exists(args.seg_save_dir) and not os.path.exists(args.vis_save_dir):\r\n            os.makedirs(args.seg_save_dir)\r\n            #os.makedirs(args.vis_save_dir)\r\n    print(\"======> test set size:\", len(testloader))\r\n    my_index = 0\r\n    old_temp=''\r\n    for index, batch in enumerate(testloader):\r\n        print('%d processd'%(index))\r\n        target = batch['target']\r\n        np.save('target.npy', target.float().data)\r\n        #search = batch['search']\r\n        temp = batch['seq_name']\r\n        args.seq_name=temp[0]\r\n        print(args.seq_name)\r\n        if old_temp==args.seq_name:\r\n            my_index = my_index+1\r\n        else:\r\n            my_index = 0\r\n        output_sum = 0   \r\n        for i in range(0,1):\r\n            search = batch['search'+'_'+str(i)]\r\n            search_im = search\r\n            print('input size:', search_im.size(),len(search.size()))\r\n            if len(search.size()) <5:\r\n                search_im = search_im.unsqueeze(0)\r\n            output = model(Variable(target, volatile=True).cuda(),Variable(search_im, volatile=True).cuda())\r\n            
#print(output[0]) # output有两个\r\n            output_sum = output_sum + output[0].data[0,0].cpu().numpy() #分割那个分支的结果\r\n\r\n            #np.save('infer'+str(i)+'.npy',output1)\r\n            #output2 = output[1].data[0, 0].cpu().numpy() #interp'\r\n        \r\n        output1 = output_sum#/args.sample_range\r\n        #target_mask = output[3].data[0,0].cpu().numpy()\r\n        #print('output size:', output1.shape, type(output1))\r\n        first_image = np.array(Image.open(args.data_dir+'/JPEGImages/480p/blackswan/00000.jpg'))\r\n        original_shape = first_image.shape \r\n        output1 = cv2.resize(output1, (original_shape[1],original_shape[0]))\r\n        #output2 = cv2.resize(target_mask, (original_shape[1], original_shape[0]))\r\n        mask = (output1*255).astype(np.uint8)\r\n        #target_mask = (output2*255).astype(np.uint8)\r\n        mask = Image.fromarray(mask)\r\n        #target_mask = Image.fromarray(target_mask)\r\n\r\n        save_dir_res = os.path.join(args.seg_save_dir, 'Results', args.seq_name)\r\n        old_temp=args.seq_name\r\n        if not os.path.exists(save_dir_res):\r\n            os.makedirs(save_dir_res)\r\n        if args.save_segimage:\r\n            my_index1 = str(my_index).zfill(5)\r\n            seg_filename = os.path.join(save_dir_res, '{}.png'.format(my_index1))\r\n            gate_filename = os.path.join(save_dir_res, '{}_gate.png'.format(my_index1))\r\n            mask.save(seg_filename)\r\n            #target_mask.save(gate_filename)\r\n\r\nif __name__ == '__main__':\r\n    main()\r\n"
  },
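  {
    "path": "examples/weighted_bce_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of the original release): the\ninverse-foreground-ratio weighting behind loss_calc1 in train_iteration_conf.py.\nThe whole map is scaled by 1/pos_ratio, so batches with little foreground\ncontribute a larger loss. Tensor sizes are arbitrary.\"\"\"\nimport torch\n\npred = torch.rand(4, 1, 473, 473)                # sigmoid outputs in [0, 1]\ngt = (torch.rand(4, 1, 473, 473) > 0.8).float()  # sparse foreground mask\n\npos_ratio = torch.ge(gt, 0.5).float().mean()     # fraction of foreground pixels\nw = (1.0 / pos_ratio).item()\nweight = torch.full_like(pred, w)                # uniform scale, as in loss_calc1\nloss = torch.nn.BCELoss(weight=weight)(pred, gt)\nprint(float(loss))\n"
  },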
  {
    "path": "train_iteration_conf.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Sat Sep 15 10:52:26 2018\r\n\r\n@author: carri\r\n\"\"\"\r\n#区别于deeplab_co_attention_concat在于采用了新的model（siamese_model_concat_new）来train\r\n\r\nimport argparse\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torch.utils import data\r\nimport numpy as np\r\nimport pickle\r\nimport cv2\r\nfrom torch.autograd import Variable\r\nimport torch.optim as optim\r\nimport scipy.misc\r\nimport torch.backends.cudnn as cudnn\r\nimport sys\r\nimport os\r\n#from utils.balanced_BCE import class_balanced_cross_entropy_loss\r\nimport os.path as osp\r\n#from psp.model import PSPNet\r\n#from dataloaders import davis_2016 as db\r\nfrom dataloaders import PairwiseImg_video as db #采用voc dataset的数据设置格式方法\r\nimport matplotlib.pyplot as plt\r\nimport random\r\nimport timeit\r\n#from psp.model1 import CoattentionNet  #基于pspnet搭建的co-attention 模型\r\nfrom deeplab.siamese_model_conf import CoattentionNet #siame_model 是直接将attend的model之后的结果输出\r\n#from deeplab.utils import get_1x_lr_params, get_10x_lr_params#, adjust_learning_rate #, loss_calc\r\nstart = timeit.default_timer()\r\n\r\ndef get_arguments():\r\n    \"\"\"Parse all the arguments provided from the CLI.\r\n    \r\n    Returns:\r\n      A list of parsed arguments.\r\n    \"\"\"\r\n    parser = argparse.ArgumentParser(description=\"PSPnet Network\")\r\n\r\n    # optimatization configuration\r\n    parser.add_argument(\"--is-training\", action=\"store_true\", \r\n                        help=\"Whether to updates the running means and variances during the training.\")\r\n    parser.add_argument(\"--learning-rate\", type=float, default= 0.00025, \r\n                        help=\"Base learning rate for training with polynomial decay.\") #0.001\r\n    parser.add_argument(\"--weight-decay\", type=float, default= 0.0005, \r\n                        help=\"Regularization parameter for L2-loss.\")  # 0.0005\r\n    parser.add_argument(\"--momentum\", type=float, default= 0.9, \r\n                        help=\"Momentum component of the optimiser.\")\r\n    parser.add_argument(\"--power\", type=float, default= 0.9, \r\n                        help=\"Decay parameter to compute the learning rate.\")\r\n    # dataset information\r\n    parser.add_argument(\"--dataset\", type=str, default='cityscapes',\r\n                        help=\"voc12, cityscapes, or pascal-context.\")\r\n    parser.add_argument(\"--random-mirror\", action=\"store_true\",\r\n                        help=\"Whether to randomly mirror the inputs during the training.\")\r\n    parser.add_argument(\"--random-scale\", action=\"store_true\",\r\n                        help=\"Whether to randomly scale the inputs during the training.\")\r\n\r\n    parser.add_argument(\"--not-restore-last\", action=\"store_true\",\r\n                        help=\"Whether to not restore last (FC) layers.\")\r\n    parser.add_argument(\"--random-seed\", type=int, default= 1234,\r\n                        help=\"Random seed to have reproducible results.\")\r\n    parser.add_argument('--logFile', default='log.txt', \r\n                        help='File that stores the training and validation logs')\r\n    # GPU configuration\r\n    parser.add_argument(\"--cuda\", default=True, help=\"Run on CPU or GPU\")\r\n    parser.add_argument(\"--gpus\", type=str, default=\"3\", help=\"choose gpu device.\") #使用3号GPU\r\n\r\n\r\n    return parser.parse_args()\r\n\r\nargs = get_arguments()\r\n\r\n\r\ndef configure_dataset_init_model(args):\r\n    if args.dataset == 'voc12':\r\n\r\n     
   args.batch_size = 10# 1 card: 5, 2 cards: 10 Number of images sent to the network in one step, 16 on paper\r\n        args.maxEpoches = 15 # 1 card: 15, 2 cards: 15 epoches, equal to 30k iterations, max iterations= maxEpoches*len(train_aug)/batch_size_per_gpu'),\r\n        args.data_dir = '/home/wty/AllDataSet/VOC2012'   # Path to the directory containing the PASCAL VOC dataset\r\n        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # Path to the file listing the images in the dataset\r\n        args.ignore_label = 255     #The index of the label to ignore during the training\r\n        args.input_size = '473,473' #Comma-separated string with height and width of images\r\n        args.num_classes = 21      #Number of classes to predict (including background)\r\n\r\n        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)\r\n        # saving model file and log record during the process of training\r\n\r\n        #Where restore model pretrained on other dataset, such as COCO.\")\r\n        args.restore_from = './pretrained/MS_DeepLab_resnet_pretrained_COCO_init.pth'\r\n        args.snapshot_dir = './snapshots/voc12/'          #Where to save snapshots of the model\r\n        args.resume = './snapshots/voc12/psp_voc12_3.pth' #checkpoint log file, helping recovering training\r\n        \r\n    elif args.dataset == 'davis': \r\n        args.batch_size = 16# 1 card: 5, 2 cards: 10 Number of images sent to the network in one step, 16 on paper\r\n        args.maxEpoches = 60 # 1 card: 15, 2 cards: 15 epoches, equal to 30k iterations, max iterations= maxEpoches*len(train_aug)/batch_size_per_gpu'),\r\n        args.data_dir = '/home/ubuntu/xiankai/dataset/DAVIS-2016'   # 37572 image pairs\r\n        args.img_dir = '/home/ubuntu/xiankai/dataset/images'\r\n        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # Path to the file listing the images in the dataset\r\n        args.ignore_label = 255     #The index of the label to ignore during the training\r\n        args.input_size = '473,473' #Comma-separated string with height and width of images\r\n        args.num_classes = 2      #Number of classes to predict (including background)\r\n        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)       # saving model file and log record during the process of training\r\n        #Where restore model pretrained on other dataset, such as COCO.\")\r\n        args.restore_from = './pretrained/deep_labv3/deeplab_davis_12_0.pth' #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #\r\n        args.snapshot_dir = './snapshots/davis_iteration_conf/'          #Where to save snapshots of the model\r\n        args.resume = './snapshots/davis/co_attention_davis_124.pth' #checkpoint log file, helping recovering training\r\n\t\t\r\n    elif args.dataset == 'cityscapes':\r\n        args.batch_size = 8   #Number of images sent to the network in one step, batch_size/num_GPU=2\r\n        args.maxEpoches = 60 #epoch nums, 60 epoches is equal to 90k iterations, max iterations= maxEpoches*len(train)/batch_size')\r\n        # 60x2975/2=89250 ~= 90k, single_GPU_batch_size=2\r\n        args.data_dir = '/home/wty/AllDataSet/CityScapes'   # Path to the directory containing the PASCAL VOC dataset\r\n        args.data_list = './dataset/list/Cityscapes/cityscapes_train_list.txt'  # Path to the file listing the images in the dataset\r\n        args.ignore_label = 255     #The index of the label to ignore 
during the training\r\n        args.input_size = '720,720' #Comma-separated string with height and width of images\r\n        args.num_classes = 19      #Number of classes to predict (including background)\r\n\r\n        args.img_mean = np.array((73.15835921, 82.90891754, 72.39239876), dtype=np.float32)\r\n        # saving model file and log record during the process of training\r\n\r\n        #Where restore model pretrained on other dataset, such as coarse cityscapes\r\n        args.restore_from = './pretrained/resnet101_pretrained_for_cityscapes.pth'\r\n        args.snapshot_dir = './snapshots/cityscapes/'          #Where to save snapshots of the model\r\n        args.resume = './snapshots/cityscapes/psp_cityscapes_12_3.pth' #checkpoint log file, helping recovering training\r\n       \r\n    else:\r\n        print(\"dataset error\")\r\n\r\ndef adjust_learning_rate(optimizer, i_iter, epoch, max_iter):\r\n    \"\"\"Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs\"\"\"\r\n    \r\n    lr = lr_poly(args.learning_rate, i_iter, max_iter, args.power, epoch)\r\n    optimizer.param_groups[0]['lr'] = lr\r\n    if i_iter%3 ==0:\r\n        optimizer.param_groups[0]['lr'] = lr\r\n        optimizer.param_groups[1]['lr'] = 0\r\n    else:\r\n        optimizer.param_groups[0]['lr'] = 0.01*lr\r\n        optimizer.param_groups[1]['lr'] = lr * 10\r\n        \r\n    return lr\r\n\r\ndef loss_calc1(pred, label):\r\n    \"\"\"\r\n    This function returns cross entropy loss for semantic segmentation\r\n    \"\"\"\r\n    labels = torch.ge(label, 0.5).float()\r\n#    \r\n    batch_size = label.size()\r\n    #print(batch_size)\r\n    num_labels_pos = torch.sum(labels) \r\n#    \r\n    batch_1 =  batch_size[0]* batch_size[2]\r\n    batch_1 = batch_1* batch_size[3]\r\n    weight_1 = torch.div(num_labels_pos, batch_1) # pos ratio\r\n    weight_1 = torch.reciprocal(weight_1)\r\n    #print(num_labels_pos, batch_1)\r\n    weight_2 = torch.div(batch_1-num_labels_pos, batch_1)\r\n    #print('postive ratio', weight_2, weight_1)\r\n    weight_22 = torch.mul(weight_1,  torch.ones(batch_size[0], batch_size[1], batch_size[2], batch_size[3]).cuda())\r\n    #weight_11 = torch.mul(weight_1,  torch.ones(batch_size[0], batch_size[1], batch_size[2]).cuda())\r\n    criterion = torch.nn.BCELoss(weight = weight_22)#weight = torch.Tensor([0,1]) .cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()\r\n    #loss = class_balanced_cross_entropy_loss(pred, label).cuda()\r\n        \r\n    return criterion(pred, label)\r\n\r\ndef loss_calc2(pred, label):\r\n    \"\"\"\r\n    This function returns cross entropy loss for semantic segmentation\r\n    \"\"\"\r\n    # out shape batch_size x channels x h x w -> batch_size x channels x h x w\r\n    # label shape h x w x 1 x batch_size  -> batch_size x 1 x h x w\r\n    # Variable(label.long()).cuda()\r\n    criterion = torch.nn.L1Loss()#.cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()\r\n    \r\n    return criterion(pred, label)\r\n\r\n\r\n\r\ndef get_1x_lr_params(model):\r\n    \"\"\"\r\n    This generator returns all the parameters of the net except for \r\n    the last classification layer. 
Note that for each batchnorm layer, \r\n    requires_grad is set to False in deeplab_resnet.py, therefore this function does not return \r\n    any batchnorm parameter\r\n    \"\"\"\r\n    b = []\r\n    if torch.cuda.device_count() == 1:\r\n        #b.append(model.encoder.conv1)\r\n        #b.append(model.encoder.bn1)\r\n        #b.append(model.encoder.layer1)\r\n        #b.append(model.encoder.layer2)\r\n        #b.append(model.encoder.layer3)\r\n        #b.append(model.encoder.layer4)\r\n        b.append(model.encoder.layer5)\r\n    else:\r\n        b.append(model.module.encoder.conv1)\r\n        b.append(model.module.encoder.bn1)\r\n        b.append(model.module.encoder.layer1)\r\n        b.append(model.module.encoder.layer2)\r\n        b.append(model.module.encoder.layer3)\r\n        b.append(model.module.encoder.layer4)\r\n        b.append(model.module.encoder.layer5)\r\n        b.append(model.module.encoder.main_classifier)\r\n    for i in range(len(b)):\r\n        for j in b[i].modules():\r\n            jj = 0\r\n            for k in j.parameters():\r\n                jj+=1\r\n                if k.requires_grad:\r\n                    yield k\r\n\r\n\r\ndef get_10x_lr_params(model):\r\n    \"\"\"\r\n    This generator returns all the parameters for the last layer of the net,\r\n    which does the classification of pixel into classes\r\n    \"\"\"\r\n    b = []\r\n    if torch.cuda.device_count() == 1:\r\n        b.append(model.linear_e.parameters())\r\n        b.append(model.main_classifier.parameters())\r\n    else:\r\n        #b.append(model.module.encoder.layer5.parameters())\r\n        b.append(model.module.linear_e.parameters())\r\n        b.append(model.module.conv1.parameters())\r\n        b.append(model.module.conv2.parameters())\r\n        b.append(model.module.gate.parameters())\r\n        b.append(model.module.bn1.parameters())\r\n        b.append(model.module.bn2.parameters())   \r\n        b.append(model.module.main_classifier1.parameters())\r\n        b.append(model.module.main_classifier2.parameters())\r\n        \r\n    for j in range(len(b)):\r\n        for i in b[j]:\r\n            yield i\r\n            \r\ndef lr_poly(base_lr, iter, max_iter, power, epoch):\r\n    if epoch<=2:\r\n        factor = 1\r\n    elif epoch>2 and epoch< 6:\r\n        factor = 1\r\n    else:\r\n        factor = 0.5\r\n    return base_lr*factor*((1-float(iter)/max_iter)**(power))\r\n\r\n\r\ndef netParams(model):\r\n    '''\r\n    Computing total network parameters\r\n    Args:\r\n       model: model\r\n    return: total network parameters\r\n    '''\r\n    total_paramters = 0\r\n    for parameter in model.parameters():\r\n        i = len(parameter.size())\r\n        #print(parameter.size())\r\n        p = 1\r\n        for j in range(i):\r\n            p *= parameter.size(j)\r\n        total_paramters += p\r\n\r\n    return total_paramters\r\n\r\ndef main():\r\n    \r\n    \r\n    print(\"=====> Configure dataset and pretrained model\")\r\n    configure_dataset_init_model(args)\r\n    print(args)\r\n\r\n    print(\"    current dataset:  \", args.dataset)\r\n    print(\"    init model: \", args.restore_from)\r\n    print(\"=====> Set GPU for training\")\r\n    if args.cuda:\r\n        print(\"====> Use gpu id: '{}'\".format(args.gpus))\r\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = args.gpus\r\n        if not torch.cuda.is_available():\r\n            raise Exception(\"No GPU found or Wrong gpu id, please run without --cuda\")\r\n    # Select which GPU, -1 if CPU\r\n    #gpu_id = 
args.gpus\r\n    #device = torch.device(\"cuda:\"+str(gpu_id) if torch.cuda.is_available() else \"cpu\")\r\n    print(\"=====> Random Seed: \", args.random_seed)\r\n    torch.manual_seed(args.random_seed)\r\n    if args.cuda:\r\n        torch.cuda.manual_seed(args.random_seed) \r\n\r\n    h, w = map(int, args.input_size.split(','))\r\n    input_size = (h, w)\r\n\r\n    cudnn.enabled = True\r\n\r\n    print(\"=====> Building network\")\r\n    saved_state_dict = torch.load(args.restore_from)\r\n    model = CoattentionNet(num_classes=args.num_classes)\r\n    #print(model)\r\n    new_params = model.state_dict().copy()\r\n    for i in saved_state_dict[\"model\"]:\r\n        #Scale.layer5.conv2d_list.3.weight\r\n        i_parts = i.split('.') # 针对多GPU的情况\r\n        #i_parts.pop(1)\r\n        #print('i_parts:  ', '.'.join(i_parts[1:-1]))\r\n        #if  not i_parts[1]=='main_classifier': #and not '.'.join(i_parts[1:-1]) == 'layer5.bottleneck' and not '.'.join(i_parts[1:-1]) == 'layer5.bn':  #init model pretrained on COCO, class name=21, layer5 is ASPP\r\n        new_params['encoder'+'.'+'.'.join(i_parts[1:])] = saved_state_dict[\"model\"][i]\r\n            #print('copy {}'.format('.'.join(i_parts[1:])))\r\n    \r\n   \r\n    print(\"=====> Loading init weights,  pretrained COCO for VOC2012, and pretrained Coarse cityscapes for cityscapes\")\r\n \r\n            \r\n    model.load_state_dict(new_params) #只用到resnet的第5个卷积层的参数\r\n    #print(model.keys())\r\n    if args.cuda:\r\n        #model.to(device)\r\n        if torch.cuda.device_count()>1:\r\n            print(\"torch.cuda.device_count()=\",torch.cuda.device_count())\r\n            model = torch.nn.DataParallel(model).cuda()  #multi-card data parallel\r\n        else:\r\n            print(\"single GPU for training\")\r\n            model = model.cuda()  #1-card data parallel\r\n    start_epoch=0\r\n    \r\n    print(\"=====> Whether resuming from a checkpoint, for continuing training\")\r\n    if args.resume:\r\n        if os.path.isfile(args.resume):\r\n            print(\"=> loading checkpoint '{}'\".format(args.resume))\r\n            checkpoint = torch.load(args.resume)\r\n            start_epoch = checkpoint[\"epoch\"] \r\n            model.load_state_dict(checkpoint[\"model\"])\r\n        else:\r\n            print(\"=> no checkpoint found at '{}'\".format(args.resume))\r\n\r\n\r\n    model.train()\r\n    cudnn.benchmark = True\r\n\r\n    if not os.path.exists(args.snapshot_dir):\r\n        os.makedirs(args.snapshot_dir)\r\n    \r\n    print('=====> Computing network parameters')\r\n    total_paramters = netParams(model)\r\n    print('Total network parameters: ' + str(total_paramters))\r\n \r\n    print(\"=====> Preparing training data\")\r\n    if args.dataset == 'voc12':\r\n        trainloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, \r\n                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), \r\n                                      batch_size= args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)\r\n    elif args.dataset == 'cityscapes':\r\n        trainloader = data.DataLoader(CityscapesDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, \r\n                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), \r\n                                      batch_size = args.batch_size, shuffle=True, num_workers=0, 
\r\n    logFileLoc = args.snapshot_dir + args.logFile\r\n    if os.path.isfile(logFileLoc):\r\n        logger = open(logFileLoc, 'a')\r\n    else:\r\n        logger = open(logFileLoc, 'w')\r\n        logger.write(\"Parameters: %s\" % (str(total_parameters)))\r\n        logger.write(\"\\n%s\\t\\t%s\" % ('iter', 'Loss(train)\\n'))\r\n    logger.flush()\r\n\r\n
    print(\"=====> Begin to train\")\r\n    train_len = len(trainloader)\r\n    print(\"  iterations per epoch: \", train_len)\r\n    print(\"  epoch num: \", args.maxEpoches)\r\n    print(\"  max iteration: \", args.maxEpoches*train_len)\r\n\r\n
    for epoch in range(start_epoch, int(args.maxEpoches)):\r\n\r\n        np.random.seed(args.random_seed + epoch)\r\n        for i_iter, batch in enumerate(trainloader, 0):  # i_iter runs from 0 to len-1\r\n            #print(\"i_iter=\", i_iter, \"epoch=\", epoch)\r\n            target, target_gt, search, search_gt = batch['target'], batch['target_gt'], batch['search'], batch['search_gt']\r\n            images, labels = batch['img'], batch['img_gt']\r\n            #print(labels.size())\r\n
            images.requires_grad_()\r\n            images = Variable(images).cuda()\r\n            labels = Variable(labels.float().unsqueeze(1)).cuda()\r\n\r\n            target.requires_grad_()\r\n            target = Variable(target).cuda()\r\n            target_gt = Variable(target_gt.float().unsqueeze(1)).cuda()\r\n\r\n            search.requires_grad_()\r\n            search = Variable(search).cuda()\r\n            search_gt = Variable(search_gt.float().unsqueeze(1)).cuda()\r\n\r\n
            optimizer.zero_grad()\r\n\r\n            lr = adjust_learning_rate(optimizer, i_iter+epoch*train_len, epoch,\r\n                    max_iter=args.maxEpoches * train_len)\r\n            #print(images.size())\r\n
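            # A sketch of the intended alternation (inferred from the branches below):\r\n            # every third iteration feeds a static saliency image paired with itself,\r\n            # down-weighted by 0.1, while the other iterations feed a video frame pair;\r\n            # each branch combines class-balanced BCE with an L1 term weighted by 0.8.\r\n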
            if i_iter % 3 == 0:  # training on static images\r\n\r\n                pred1, pred2, pred3 = model(images, images)\r\n                loss = 0.1*(loss_calc1(pred3, labels) + 0.8*loss_calc2(pred3, labels))\r\n                loss.backward()\r\n\r\n
            else:\r\n\r\n                pred1, pred2, pred3 = model(target, search)\r\n                loss = loss_calc1(pred1, target_gt) + 0.8*loss_calc2(pred1, target_gt) + loss_calc1(pred2, search_gt) + 0.8*loss_calc2(pred2, search_gt)  #class_balanced_cross_entropy_loss(pred, labels, size_average=False)\r\n                loss.backward()\r\n\r\n
            optimizer.step()\r\n\r\n            print(\"===> Epoch[{}]({}/{}): Loss: {:.10f}  lr: {:.5f}\".format(epoch, i_iter, train_len, loss.data, lr))\r\n            logger.write(\"Epoch[{}]({}/{}):     Loss: {:.10f}      lr: {:.5f}\\n\".format(epoch, i_iter, train_len, loss.data, lr))\r\n            logger.flush()\r\n\r\n
        print(\"=====> saving model\")\r\n        state = {\"epoch\": epoch+1, \"model\": model.state_dict()}\r\n        torch.save(state, osp.join(args.snapshot_dir, 'co_attention_'+str(args.dataset)+\"_\"+str(epoch)+'.pth'))\r\n\r\n\r\n
    end = timeit.default_timer()\r\n    print(\"total training time: {:.2f} h\".format(float(end-start)/3600))\r\n    logger.write(\"total training time: {:.2f} h\\n\".format(float(end-start)/3600))\r\n    logger.close()\r\n\r\n\r\nif __name__ == '__main__':\r\n    main()\r\n"
  },
  {
    "path": "train_iteration_conf_group.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Sat Sep 15 10:52:26 2018\r\n\r\n@author: carri\r\n\"\"\"\r\n# Differs from deeplab_co_attention_concat in that training uses the new model (siamese_model_concat_new)\r\n\r\n
import argparse\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torch.utils import data\r\nimport numpy as np\r\nimport pickle\r\nimport cv2\r\nfrom torch.autograd import Variable\r\nimport torch.optim as optim\r\nimport scipy.misc\r\nimport torch.backends.cudnn as cudnn\r\nimport sys\r\nimport os\r\n
from utils.balanced_BCE import class_balanced_cross_entropy_loss\r\nimport os.path as osp\r\n#from psp.model import PSPNet\r\n#from dataloaders import davis_2016 as db\r\nfrom dataloaders import PairwiseImg_video_try as db  # uses the VOC-dataset-style data configuration\r\nimport matplotlib.pyplot as plt\r\nimport random\r\nimport timeit\r\n
#from psp.model1 import CoattentionNet  # co-attention model built on PSPNet\r\nfrom deeplab.siamese_model_conf_try import CoattentionNet  # the siamese model directly outputs the attended result\r\n#from deeplab.utils import get_1x_lr_params, get_10x_lr_params#, adjust_learning_rate #, loss_calc\r\nstart = timeit.default_timer()\r\n\r\n
def get_arguments():\r\n    \"\"\"Parse all the arguments provided from the CLI.\r\n\r\n    Returns:\r\n      A list of parsed arguments.\r\n    \"\"\"\r\n    parser = argparse.ArgumentParser(description=\"PSPnet Network\")\r\n\r\n
    # optimization configuration\r\n    parser.add_argument(\"--is-training\", action=\"store_true\",\r\n                        help=\"Whether to update the running means and variances during training.\")\r\n    parser.add_argument(\"--learning-rate\", type=float, default=0.00025,\r\n                        help=\"Base learning rate for training with polynomial decay.\") #0.001\r\n    parser.add_argument(\"--weight-decay\", type=float, default=0.0005,\r\n                        help=\"Regularization parameter for L2-loss.\")  # 0.0005\r\n    parser.add_argument(\"--momentum\", type=float, default=0.9,\r\n                        help=\"Momentum component of the optimiser.\")\r\n    parser.add_argument(\"--power\", type=float, default=0.9,\r\n                        help=\"Decay parameter to compute the learning rate.\")\r\n
    # dataset information\r\n    parser.add_argument(\"--dataset\", type=str, default='cityscapes',\r\n                        help=\"voc12, cityscapes, or pascal-context.\")\r\n    parser.add_argument(\"--random-mirror\", action=\"store_true\",\r\n                        help=\"Whether to randomly mirror the inputs during the training.\")\r\n    parser.add_argument(\"--random-scale\", action=\"store_true\",\r\n                        help=\"Whether to randomly scale the inputs during the training.\")\r\n\r\n
    parser.add_argument(\"--not-restore-last\", action=\"store_true\",\r\n                        help=\"Whether to not restore last (FC) layers.\")\r\n    parser.add_argument(\"--random-seed\", type=int, default=1234,\r\n                        help=\"Random seed to have reproducible results.\")\r\n    parser.add_argument('--logFile', default='log.txt',\r\n                        help='File that stores the training and validation logs')\r\n
    # GPU configuration\r\n    parser.add_argument(\"--cuda\", default=True, help=\"Run on CPU or GPU\")\r\n    parser.add_argument(\"--gpus\", type=str, default=\"3\", help=\"choose gpu device.\")  # GPU 3 by default\r\n\r\n    return parser.parse_args()\r\n\r\nargs = get_arguments()\r\n\r\n\r\n
def configure_dataset_init_model(args):\r\n    if args.dataset == 'voc12':\r\n
\r\n        args.batch_size = 10  # 1 card: 5, 2 cards: 10; number of images sent to the network in one step (16 in the paper)\r\n        args.maxEpoches = 15  # 1 card: 15, 2 cards: 15 epochs, equal to 30k iterations; max iterations = maxEpoches*len(train_aug)/batch_size_per_gpu\r\n        args.data_dir = '/home/wty/AllDataSet/VOC2012'   # Path to the directory containing the PASCAL VOC dataset\r\n        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # Path to the file listing the images in the dataset\r\n        args.ignore_label = 255     # The index of the label to ignore during training\r\n        args.input_size = '473,473' # Comma-separated string with height and width of images\r\n        args.num_classes = 21      # Number of classes to predict (including background)\r\n\r\n
        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)\r\n        # saving model file and log record during the process of training\r\n\r\n        # Where to restore the model pretrained on another dataset, such as COCO\r\n        args.restore_from = './pretrained/MS_DeepLab_resnet_pretrained_COCO_init.pth'\r\n        args.snapshot_dir = './snapshots/voc12/'          # Where to save snapshots of the model\r\n        args.resume = './snapshots/voc12/psp_voc12_3.pth' # checkpoint file for resuming training\r\n\r\n
    elif args.dataset == 'davis':\r\n        args.batch_size = 16  # 1 card: 5, 2 cards: 10; number of images sent to the network in one step (16 in the paper)\r\n        args.maxEpoches = 60  # max iterations = maxEpoches*len(train)/batch_size_per_gpu\r\n        args.data_dir = '/home/ubuntu/xiankai/dataset/DAVIS-2016'   # 37572 image pairs\r\n        args.img_dir = '/home/ubuntu/xiankai/dataset/images'\r\n        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # Path to the file listing the images in the dataset\r\n        args.ignore_label = 255     # The index of the label to ignore during training\r\n        args.input_size = '378,378' # Comma-separated string with height and width of images\r\n        args.num_classes = 2      # Number of classes to predict (including background)\r\n        args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)\r\n
        # Where to restore the model pretrained on another dataset, such as COCO\r\n        args.restore_from = './pretrained/deep_labv3/deeplab_davis_12_0.pth'  # alternatives: resnet50-19c8e357.pth, /home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth\r\n        args.snapshot_dir = './snapshots/davis_iteration_conf_try/'          # Where to save snapshots of the model\r\n        args.resume = './snapshots/davis/co_attention_davis_124.pth' # checkpoint file for resuming training\r\n\r\n
    elif args.dataset == 'cityscapes':\r\n        args.batch_size = 8   # Number of images sent to the network in one step, batch_size/num_GPU=2\r\n        args.maxEpoches = 60  # 60 epochs is equal to 90k iterations; max iterations = maxEpoches*len(train)/batch_size\r\n        # 60x2975/2=89250 ~= 90k, single_GPU_batch_size=2\r\n        args.data_dir = '/home/wty/AllDataSet/CityScapes'   # Path to the directory containing the Cityscapes dataset\r\n        args.data_list = './dataset/list/Cityscapes/cityscapes_train_list.txt'  # Path to the file listing the images in the dataset\r\n        args.ignore_label = 255     # The index of the label to ignore during training\r\n        args.input_size = '720,720' # Comma-separated string with height and width of images\r\n        args.num_classes = 19      # Number of classes to predict (including background)\r\n\r\n
        args.img_mean = np.array((73.15835921, 82.90891754, 72.39239876), dtype=np.float32)\r\n        # saving model file and log record during the process of training\r\n\r\n        # Where to restore the model pretrained on another dataset, such as coarse Cityscapes\r\n        args.restore_from = './pretrained/resnet101_pretrained_for_cityscapes.pth'\r\n        args.snapshot_dir = './snapshots/cityscapes/'          # Where to save snapshots of the model\r\n        args.resume = './snapshots/cityscapes/psp_cityscapes_12_3.pth' # checkpoint file for resuming training\r\n\r\n
    else:\r\n        print(\"dataset error\")\r\n\r\n
def adjust_learning_rate(optimizer, i_iter, epoch, max_iter):\r\n    \"\"\"Polynomial decay of the learning rate; on static-image iterations (i_iter % 3 == 0)\r\n    only the encoder group is updated, while on video iterations the encoder group runs at\r\n    0.01x the decayed rate and the attention/classifier group at 10x.\"\"\"\r\n\r\n
    lr = lr_poly(args.learning_rate, i_iter, max_iter, args.power, epoch)\r\n    optimizer.param_groups[0]['lr'] = lr\r\n    if i_iter % 3 == 0:\r\n        optimizer.param_groups[0]['lr'] = lr\r\n        optimizer.param_groups[1]['lr'] = 0\r\n    else:\r\n        optimizer.param_groups[0]['lr'] = 0.01*lr\r\n        optimizer.param_groups[1]['lr'] = lr * 10\r\n\r\n    return lr\r\n
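\r\n# Illustrative numbers for the schedule above (not from the paper): if the decayed\r\n# base rate at some iteration is lr = 1e-4, then on a static-image iteration the\r\n# encoder group trains at 1e-4 and the attention/classifier group is frozen (lr 0),\r\n# whereas on a video iteration the encoder group trains at 1e-6 and the\r\n# attention/classifier group at 1e-3.\r\n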
\r\ndef loss_calc1(pred, label):\r\n    \"\"\"\r\n    Class-balanced binary cross-entropy: every pixel's BCE term is weighted by the\r\n    reciprocal of the foreground (positive) ratio of the batch.\r\n    \"\"\"\r\n    labels = torch.ge(label, 0.5).float()\r\n\r\n
    batch_size = label.size()\r\n    #print(batch_size)\r\n    num_labels_pos = torch.sum(labels)\r\n\r\n    batch_1 = batch_size[0] * batch_size[2]\r\n    batch_1 = batch_1 * batch_size[3]  # total number of pixels in the batch\r\n    weight_1 = torch.div(num_labels_pos, batch_1)  # positive ratio\r\n    weight_1 = torch.reciprocal(weight_1)\r\n    #print(num_labels_pos, batch_1)\r\n    weight_2 = torch.div(batch_1-num_labels_pos, batch_1)  # negative ratio (currently unused)\r\n    #print('positive ratio', weight_2, weight_1)\r\n
    weight_22 = torch.mul(weight_1, torch.ones(batch_size[0], batch_size[1], batch_size[2], batch_size[3]).cuda())\r\n    #weight_11 = torch.mul(weight_1, torch.ones(batch_size[0], batch_size[1], batch_size[2]).cuda())\r\n    criterion = torch.nn.BCELoss(weight=weight_22)  #weight = torch.Tensor([0,1]).cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()\r\n    #loss = class_balanced_cross_entropy_loss(pred, label).cuda()\r\n\r\n    return criterion(pred, label)\r\n
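\r\n# A small sanity check of the weighting above (illustrative only): for a batch of\r\n# 2 masks of size 1x4x4 with 8 foreground pixels in total, the positive ratio is\r\n# 8/32 = 0.25, so every pixel's BCE contribution is scaled by 1/0.25 = 4.\r\n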
\r\ndef loss_calc2(pred, label):\r\n    \"\"\"\r\n    L1 loss between the prediction and the label map.\r\n    \"\"\"\r\n    # out shape batch_size x channels x h x w -> batch_size x channels x h x w\r\n    # label shape h x w x 1 x batch_size -> batch_size x 1 x h x w\r\n    criterion = torch.nn.L1Loss()#.cuda()\r\n\r\n    return criterion(pred, label)\r\n\r\n\r\n
def get_1x_lr_params(model):\r\n    \"\"\"\r\n    This generator returns all the parameters of the net except for the last\r\n    classification layer. Note that requires_grad is set to False for every\r\n    batchnorm layer in deeplab_resnet.py, so this generator does not return\r\n    any batchnorm parameters.\r\n    \"\"\"\r\n    b = []\r\n
    if torch.cuda.device_count() == 1:\r\n        #b.append(model.encoder.conv1)\r\n        #b.append(model.encoder.bn1)\r\n        #b.append(model.encoder.layer1)\r\n        #b.append(model.encoder.layer2)\r\n        #b.append(model.encoder.layer3)\r\n        #b.append(model.encoder.layer4)\r\n        b.append(model.encoder.layer5)\r\n
    else:\r\n        b.append(model.module.encoder.conv1)\r\n        b.append(model.module.encoder.bn1)\r\n        b.append(model.module.encoder.layer1)\r\n        b.append(model.module.encoder.layer2)\r\n        b.append(model.module.encoder.layer3)\r\n        b.append(model.module.encoder.layer4)\r\n        b.append(model.module.encoder.layer5)\r\n        b.append(model.module.encoder.main_classifier)\r\n
    for block in b:\r\n        # parameters() is already recursive, so additionally looping over block.modules()\r\n        # would yield each parameter several times; iterate the blocks directly instead\r\n        for param in block.parameters():\r\n            if param.requires_grad:\r\n                yield param\r\n\r\n\r\n
def get_10x_lr_params(model):\r\n    \"\"\"\r\n    This generator returns the parameters of the co-attention module and the\r\n    classification head, which are trained with a 10x learning rate.\r\n    \"\"\"\r\n    b = []\r\n
    if torch.cuda.device_count() == 1:\r\n        b.append(model.linear_e.parameters())\r\n        b.append(model.main_classifier.parameters())\r\n
    else:\r\n        #b.append(model.module.encoder.layer5.parameters())\r\n        b.append(model.module.linear_e.parameters())\r\n        b.append(model.module.conv1.parameters())\r\n        #b.append(model.module.conv2.parameters())\r\n        b.append(model.module.gate.parameters())\r\n        b.append(model.module.bn1.parameters())\r\n        #b.append(model.module.bn2.parameters())\r\n        b.append(model.module.main_classifier1.parameters())\r\n        #b.append(model.module.main_classifier2.parameters())\r\n\r\n
    for group in b:\r\n        for param in group:\r\n            yield param\r\n\r\n\r\n
def lr_poly(base_lr, iter, max_iter, power, epoch):\r\n    # polynomial decay of the base rate, additionally halved from epoch 6 onwards\r\n    factor = 1 if epoch < 6 else 0.5\r\n    return base_lr*factor*((1-float(iter)/max_iter)**(power))\r\n\r\n\r\n
def netParams(model):\r\n    '''\r\n    Computes the total number of network parameters\r\n    Args:\r\n       model: model\r\n    return: total number of network parameters\r\n    '''\r\n
    total_parameters = 0\r\n    for parameter in model.parameters():\r\n        i = len(parameter.size())\r\n        #print(parameter.size())\r\n        p = 1\r\n        for j in range(i):\r\n            p *= parameter.size(j)\r\n        total_parameters += p\r\n\r\n    return total_parameters\r\n\r\n
def main():\r\n    print(\"=====> Configure dataset and pretrained model\")\r\n    configure_dataset_init_model(args)\r\n    print(args)\r\n\r\n    print(\"    current dataset:  \", args.dataset)\r\n    print(\"    init model: \", args.restore_from)\r\n
    print(\"=====> Set GPU for training\")\r\n    if args.cuda:\r\n        print(\"====> Use gpu id: '{}'\".format(args.gpus))\r\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = args.gpus\r\n        if not torch.cuda.is_available():\r\n            raise Exception(\"No GPU found or wrong GPU id; please run without --cuda\")\r\n
    # Select which GPU, -1 if CPU\r\n    #gpu_id = args.gpus\r\n    #device = torch.device(\"cuda:\"+str(gpu_id) if torch.cuda.is_available() else \"cpu\")\r\n
    print(\"=====> Random Seed: \", args.random_seed)\r\n    torch.manual_seed(args.random_seed)\r\n    if args.cuda:\r\n        torch.cuda.manual_seed(args.random_seed)\r\n\r\n    h, w = map(int, args.input_size.split(','))\r\n    input_size = (h, w)\r\n\r\n    cudnn.enabled = True\r\n\r\n
    print(\"=====> Building network\")\r\n    saved_state_dict = torch.load(args.restore_from)\r\n    model = CoattentionNet(num_classes=args.num_classes)\r\n    #print(model)\r\n    new_params = model.state_dict().copy()\r\n
    for i in saved_state_dict[\"model\"]:\r\n        # e.g. Scale.layer5.conv2d_list.3.weight\r\n        i_parts = i.split('.')  # strip the leading module name to handle multi-GPU (DataParallel) checkpoints\r\n        #i_parts.pop(1)\r\n        #print('i_parts:  ', '.'.join(i_parts[1:-1]))\r\n        #if  not i_parts[1]=='main_classifier': #and not '.'.join(i_parts[1:-1]) == 'layer5.bottleneck' and not '.'.join(i_parts[1:-1]) == 'layer5.bn':  #init model pretrained on COCO, class name=21, layer5 is ASPP\r\n        new_params['encoder'+'.'+'.'.join(i_parts[1:])] = saved_state_dict[\"model\"][i]\r\n            #print('copy {}'.format('.'.join(i_parts[1:])))\r\n\r\n
    print(\"=====> Loading init weights, pretrained COCO for VOC2012, and pretrained coarse Cityscapes for Cityscapes\")\r\n\r\n    model.load_state_dict(new_params)  # only the parameters of ResNet's 5th conv block are reused\r\n    #print(model.keys())\r\n
    if args.cuda:\r\n        #model.to(device)\r\n        if torch.cuda.device_count() > 1:\r\n            print(\"torch.cuda.device_count()=\", torch.cuda.device_count())\r\n            model = torch.nn.DataParallel(model).cuda()  # multi-card data parallel\r\n        else:\r\n            print(\"single GPU for training\")\r\n            model = model.cuda()  # 1-card data parallel\r\n    start_epoch = 0\r\n\r\n
    print(\"=====> Whether resuming from a checkpoint, for continuing training\")\r\n    if args.resume:\r\n        if os.path.isfile(args.resume):\r\n            print(\"=> loading checkpoint '{}'\".format(args.resume))\r\n            checkpoint = torch.load(args.resume)\r\n            start_epoch = checkpoint[\"epoch\"]\r\n            model.load_state_dict(checkpoint[\"model\"])\r\n        else:\r\n            print(\"=> no checkpoint found at '{}'\".format(args.resume))\r\n\r\n
    model.train()\r\n    cudnn.benchmark = True\r\n\r\n    if not os.path.exists(args.snapshot_dir):\r\n        os.makedirs(args.snapshot_dir)\r\n\r\n    print('=====> Computing network parameters')\r\n    total_parameters = netParams(model)\r\n    print('Total network parameters: ' + str(total_parameters))\r\n\r\n
    print(\"=====> Preparing training data\")\r\n    if args.dataset == 'voc12':\r\n        trainloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size,\r\n                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean),\r\n                                      batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)\r\n
    elif args.dataset == 'cityscapes':\r\n        trainloader = data.DataLoader(CityscapesDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size,\r\n                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean),\r\n                                      batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)\r\n
    elif args.dataset == 'davis':  # for DAVIS 2016\r\n        db_train = db.PairwiseImg(train=True, inputRes=input_size, db_root_dir=args.data_dir, img_root_dir=args.img_dir, transform=None)  # db_root_dir --> '/path/to/DAVIS-2016' train path\r\n        trainloader = data.DataLoader(db_train, batch_size=args.batch_size, shuffle=True, num_workers=0)\r\n
    else:\r\n        print(\"dataset error\")\r\n\r\n
    optimizer = optim.SGD([{'params': get_1x_lr_params(model), 'lr': 1*args.learning_rate},  # only selected layers are trained; the others are kept frozen\r\n                {'params': get_10x_lr_params(model), 'lr': 10*args.learning_rate}],\r\n                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)\r\n    optimizer.zero_grad()\r\n
\r\n    logFileLoc = args.snapshot_dir + args.logFile\r\n    if os.path.isfile(logFileLoc):\r\n        logger = open(logFileLoc, 'a')\r\n    else:\r\n        logger = open(logFileLoc, 'w')\r\n        logger.write(\"Parameters: %s\" % (str(total_parameters)))\r\n        logger.write(\"\\n%s\\t\\t%s\" % ('iter', 'Loss(train)\\n'))\r\n    logger.flush()\r\n\r\n
    print(\"=====> Begin to train\")\r\n    train_len = len(trainloader)\r\n    print(\"  iterations per epoch: \", train_len)\r\n    print(\"  epoch num: \", args.maxEpoches)\r\n    print(\"  max iteration: \", args.maxEpoches*train_len)\r\n\r\n
    for epoch in range(start_epoch, int(args.maxEpoches)):\r\n\r\n        np.random.seed(args.random_seed + epoch)\r\n        for i_iter, batch in enumerate(trainloader, 0):  # i_iter runs from 0 to len-1\r\n            #print(\"i_iter=\", i_iter, \"epoch=\", epoch)\r\n            target, target_gt, search, search_gt = batch['target'], batch['target_grt'], batch['search'], batch['search_grt']\r\n            images, labels = batch['img'], batch['img_grt']\r\n            #print('input size:', len(target), target.size(), labels.size())\r\n            # 8,2,3,473,473\r\n
            images.requires_grad_()\r\n            images = Variable(images).cuda()\r\n            labels = Variable(labels.float().unsqueeze(1)).cuda()\r\n\r\n            target.requires_grad_()\r\n            target = Variable(target).cuda()\r\n            target_gt = Variable(target_gt.float().unsqueeze(1)).cuda()\r\n\r\n            search.requires_grad_()\r\n            search = Variable(search).cuda()\r\n            search_gt = Variable(search_gt.float().unsqueeze(1)).cuda()\r\n\r\n
            optimizer.zero_grad()\r\n\r\n            lr = adjust_learning_rate(optimizer, i_iter+epoch*train_len, epoch,\r\n                    max_iter=args.maxEpoches * train_len)\r\n            #print(images.size())\r\n
            if i_iter % 3 == 0:  # training on static images\r\n\r\n                pred1, pred2 = model(images, images)\r\n                loss = 0.1*(loss_calc1(pred2, labels) + 0.8*loss_calc2(pred2, labels))\r\n                loss.backward()\r\n\r\n
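            # Unlike train_iteration_conf.py, this group variant returns two outputs and,\r\n            # on video iterations below, supervises only the target-frame prediction\r\n            # (pred1); the search frame serves purely as the co-attention reference.\r\n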
            else:\r\n\r\n                pred1, pred2 = model(target, search)\r\n                #print('video prediction size:', pred2.size(), target_gt.size())\r\n                loss = loss_calc1(pred1, target_gt) + 0.8*loss_calc2(pred1, target_gt)\r\n                loss.backward()\r\n\r\n
            optimizer.step()\r\n\r\n            print(\"===> Epoch[{}]({}/{}): Loss: {:.10f}  lr: {:.5f}\".format(epoch, i_iter, train_len, loss.data, lr))\r\n            logger.write(\"Epoch[{}]({}/{}):     Loss: {:.10f}      lr: {:.5f}\\n\".format(epoch, i_iter, train_len, loss.data, lr))\r\n            logger.flush()\r\n\r\n
        print(\"=====> saving model\")\r\n        state = {\"epoch\": epoch+1, \"model\": model.state_dict()}\r\n        torch.save(state, osp.join(args.snapshot_dir, 'co_attention_'+str(args.dataset)+\"_\"+str(epoch)+'.pth'))\r\n\r\n\r\n
    end = timeit.default_timer()\r\n    print(\"total training time: {:.2f} h\".format(float(end-start)/3600))\r\n    logger.write(\"total training time: {:.2f} h\\n\".format(float(end-start)/3600))\r\n    logger.close()\r\n\r\n\r\nif __name__ == '__main__':\r\n    main()\r\n"
  }
]