[
  {
    "path": ".gitignore",
    "content": "orlaye/__pacache__\ntensorlaye/.DS_Store\n.DS_Store\ndist\nbuild/\ntensorlayer.egg-info\ndata/.DS_Store\n*.pyc\n*.gz\n"
  },
  {
    "path": "README.md",
    "content": "# U-Net Brain Tumor Segmentation \r\n\r\n🚀：Feb 2019 the data processing implementation in this repo is not the fastest way (code need update, contribution is welcome), you can use TensorFlow dataset API instead.\r\n\r\nThis repo show you how to train a U-Net for brain tumor segmentation. By default, you need to download the training set of [BRATS 2017](http://braintumorsegmentation.org) dataset, which have 210 HGG and 75 LGG volumes, and put the data folder along with all scripts.\r\n\r\n```bash\r\ndata\r\n  -- Brats17TrainingData\r\n  -- train_dev_all\r\nmodel.py\r\ntrain.py\r\n...\r\n```\r\n\r\n### About the data\r\nNote that according to the license, user have to apply the dataset from BRAST, please do **NOT** contact me for the dataset. Many thanks.\r\n\r\n<div align=\"center\">\r\n    <img src=\"https://github.com/zsdonghao/u-net-brain-tumor/blob/master/example/brain_tumor_data.png\" width=\"80%\" height=\"50%\"/>\r\n    <br>  \r\n    <em align=\"center\">Fig 1: Brain Image</em>  \r\n</div>\r\n\r\n* Each volume have 4 scanning images: FLAIR、T1、T1c and T2.\r\n* Each volume have 4 segmentation labels:\r\n\r\n```\r\nLabel 0: background\r\nLabel 1: necrotic and non-enhancing tumor\r\nLabel 2: edema \r\nLabel 4: enhancing tumor\r\n```\r\n\r\nThe `prepare_data_with_valid.py` split the training set into 2 folds for training and validating. By default, it will use only half of the data for the sake of training speed, if you want to use all data, just change `DATA_SIZE = 'half'` to `all`.\r\n\r\n### About the method\r\n\r\n- Network and Loss: In this experiment, as we use [dice loss](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#dice-coefficient) to train a network, one network only predict one labels (Label 1,2 or 4). We evaluate the performance using [hard dice](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#hard-dice-coefficient) and [IOU](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#iou-coefficient).\r\n\r\n- Data augmenation: Includes random left and right flip, rotation, shifting, shearing, zooming and the most important one -- [Elastic trasnformation](http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html#elastic-transform), see [\"Automatic Brain Tumor Detection and Segmentation Using U-Net Based Fully Convolutional Networks\"](https://arxiv.org/pdf/1705.03820.pdf) for details.\r\n\r\n<div align=\"center\">\r\n    <img src=\"https://github.com/zsdonghao/u-net-brain-tumor/blob/master/example/brain_tumor_aug.png\" width=\"80%\" height=\"50%\"/>\r\n    <br>  \r\n    <em align=\"center\">Fig 2: Data augmentation</em>  \r\n</div>\r\n\r\n### Start training\r\n\r\nWe train HGG and LGG together, as one network only have one task, set the `task` to `all`, `necrotic`, `edema` or `enhance`, \"all\" means learn to segment all tumors.\r\n\r\n```\r\npython train.py --task=all\r\n```\r\n\r\nNote that, if the loss stick on 1 at the beginning, it means the network doesn't converge to near-perfect accuracy, please try restart it.\r\n\r\n### Citation\r\nIf you find this project useful, we would be grateful if you cite the TensorLayer paper：\r\n\r\n```\r\n@article{tensorlayer2017,\r\nauthor = {Dong, Hao and Supratak, Akara and Mai, Luo and Liu, Fangde and Oehmichen, Axel and Yu, Simiao and Guo, Yike},\r\njournal = {ACM Multimedia},\r\ntitle = {{TensorLayer: A Versatile Library for Efficient Deep Learning Development}},\r\nurl = {http://tensorlayer.org},\r\nyear = {2017}\r\n}\r\n```\r\n\r\n"
  },
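  {
    "path": "example_losses.py",
    "content": "\"\"\"Hypothetical demo, not part of the original repo: a minimal sketch of the\r\nthree metrics the README mentions -- soft dice, hard dice and IoU from tl.cost,\r\ncalled with the same axis argument as train.py. With a perfect prediction all\r\nthree should evaluate to values close to 1.0.\r\n\"\"\"\r\nimport numpy as np\r\nimport tensorflow as tf\r\nimport tensorlayer as tl\r\n\r\n# a toy binary segmentation map, used as both prediction and target\r\ny = np.zeros((1, 240, 240, 1), dtype=np.float32)\r\ny[0, 100:140, 100:140, 0] = 1.0\r\nt_pred = tf.constant(y)\r\nt_target = tf.constant(y)\r\n\r\n# same calls and axis argument as train.py\r\ndice = tl.cost.dice_coe(t_pred, t_target, axis=[0, 1, 2, 3])\r\ndice_hard = tl.cost.dice_hard_coe(t_pred, t_target, axis=[0, 1, 2, 3])\r\niou = tl.cost.iou_coe(t_pred, t_target, axis=[0, 1, 2, 3])\r\n\r\nwith tf.Session() as sess:\r\n    print(sess.run([dice, dice_hard, iou]))  # expect ~[1.0, 1.0, 1.0]\r\n"
  },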
  {
    "path": "model.py",
    "content": "import tensorflow as tf\r\nimport tensorlayer as tl\r\nfrom tensorlayer.layers import *\r\nimport numpy as np\r\n\r\n\r\nfrom tensorlayer.layers import *\r\ndef u_net(x, is_train=False, reuse=False, n_out=1):\r\n    _, nx, ny, nz = x.get_shape().as_list()\r\n    with tf.variable_scope(\"u_net\", reuse=reuse):\r\n        tl.layers.set_name_reuse(reuse)\r\n        inputs = InputLayer(x, name='inputs')\r\n        conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, name='conv1_1')\r\n        conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, name='conv1_2')\r\n        pool1 = MaxPool2d(conv1, (2, 2), name='pool1')\r\n        conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, name='conv2_1')\r\n        conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, name='conv2_2')\r\n        pool2 = MaxPool2d(conv2, (2, 2), name='pool2')\r\n        conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, name='conv3_1')\r\n        conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, name='conv3_2')\r\n        pool3 = MaxPool2d(conv3, (2, 2), name='pool3')\r\n        conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, name='conv4_1')\r\n        conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, name='conv4_2')\r\n        pool4 = MaxPool2d(conv4, (2, 2), name='pool4')\r\n        conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, name='conv5_1')\r\n        conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, name='conv5_2')\r\n\r\n        up4 = DeConv2d(conv5, 512, (3, 3), (nx/8, ny/8), (2, 2), name='deconv4')\r\n        up4 = ConcatLayer([up4, conv4], 3, name='concat4')\r\n        conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, name='uconv4_1')\r\n        conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, name='uconv4_2')\r\n        up3 = DeConv2d(conv4, 256, (3, 3), (nx/4, ny/4), (2, 2), name='deconv3')\r\n        up3 = ConcatLayer([up3, conv3], 3, name='concat3')\r\n        conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, name='uconv3_1')\r\n        conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, name='uconv3_2')\r\n        up2 = DeConv2d(conv3, 128, (3, 3), (nx/2, ny/2), (2, 2), name='deconv2')\r\n        up2 = ConcatLayer([up2, conv2], 3, name='concat2')\r\n        conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu,  name='uconv2_1')\r\n        conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, name='uconv2_2')\r\n        up1 = DeConv2d(conv2, 64, (3, 3), (nx/1, ny/1), (2, 2), name='deconv1')\r\n        up1 = ConcatLayer([up1, conv1] , 3, name='concat1')\r\n        conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, name='uconv1_1')\r\n        conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, name='uconv1_2')\r\n        conv1 = Conv2d(conv1, n_out, (1, 1), act=tf.nn.sigmoid, name='uconv1')\r\n    return conv1\r\n\r\n# def u_net(x, is_train=False, reuse=False, pad='SAME', n_out=2):\r\n#     \"\"\" Original U-Net for cell segmentataion\r\n#     http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/\r\n#     Original x is [batch_size, 572, 572, ?], pad is VALID\r\n#     \"\"\"\r\n#     from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer\r\n#     nx = int(x._shape[1])\r\n#     ny = int(x._shape[2])\r\n#     nz = int(x._shape[3])\r\n#     print(\" * Input: size of image: %d %d %d\" % (nx, ny, nz))\r\n#\r\n#     w_init = tf.truncated_normal_initializer(stddev=0.01)\r\n#     b_init = tf.constant_initializer(value=0.0)\r\n#     with tf.variable_scope(\"u_net\", reuse=reuse):\r\n#         tl.layers.set_name_reuse(reuse)\r\n#         inputs = InputLayer(x, 
name='inputs')\r\n#\r\n#         conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv1_1')\r\n#         conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv1_2')\r\n#         pool1 = MaxPool2d(conv1, (2, 2), padding=pad, name='pool1')\r\n#\r\n#         conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv2_1')\r\n#         conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv2_2')\r\n#         pool2 = MaxPool2d(conv2, (2, 2), padding=pad, name='pool2')\r\n#\r\n#         conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv3_1')\r\n#         conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv3_2')\r\n#         pool3 = MaxPool2d(conv3, (2, 2), padding=pad, name='pool3')\r\n#\r\n#         conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv4_1')\r\n#         conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv4_2')\r\n#         pool4 = MaxPool2d(conv4, (2, 2), padding=pad, name='pool4')\r\n#\r\n#         conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv5_1')\r\n#         conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='conv5_2')\r\n#\r\n#         print(\" * After conv: %s\" % conv5.outputs)\r\n#\r\n#         up4 = DeConv2d(conv5, 512, (3, 3), out_size = (nx/8, ny/8),\r\n#                     strides=(2, 2), padding=pad, act=None,\r\n#                     W_init=w_init, b_init=b_init, name='deconv4')\r\n#         up4 = ConcatLayer([up4, conv4], concat_dim=3, name='concat4')\r\n#         conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='uconv4_1')\r\n#         conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='uconv4_2')\r\n#\r\n#         up3 = DeConv2d(conv4, 256, (3, 3), out_size = (nx/4, ny/4),\r\n#                     strides=(2, 2), padding=pad, act=None,\r\n#                     W_init=w_init, b_init=b_init, name='deconv3')\r\n#         up3 = ConcatLayer([up3, conv3], concat_dim=3, name='concat3')\r\n#         conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='uconv3_1')\r\n#         conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='uconv3_2')\r\n#\r\n#         up2 = DeConv2d(conv3, 128, (3, 3), out_size=(nx/2, ny/2),\r\n#                     strides=(2, 2), padding=pad, act=None,\r\n#                     W_init=w_init, b_init=b_init, name='deconv2')\r\n#         up2 = ConcatLayer([up2, conv2] ,concat_dim=3, name='concat2')\r\n#         conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='uconv2_1')\r\n#         conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, 
padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='uconv2_2')\r\n#\r\n#         up1 = DeConv2d(conv2, 64, (3, 3), out_size=(nx/1, ny/1),\r\n#                     strides=(2, 2), padding=pad, act=None,\r\n#                     W_init=w_init, b_init=b_init, name='deconv1')\r\n#         up1 = ConcatLayer([up1, conv1] ,concat_dim=3, name='concat1')\r\n#         conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='uconv1_1')\r\n#         conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding=pad,\r\n#                     W_init=w_init, b_init=b_init, name='uconv1_2')\r\n#\r\n#         conv1 = Conv2d(conv1, n_out, (1, 1), act=tf.nn.sigmoid, name='uconv1')\r\n#         print(\" * Output: %s\" % conv1.outputs)\r\n#\r\n#         # logits0 = conv1.outputs[:,:,:,0]            # segmentataion\r\n#         # logits1 = conv1.outputs[:,:,:,1]            # edge\r\n#         # logits0 = tf.expand_dims(logits0, axis=3)\r\n#         # logits1 = tf.expand_dims(logits1, axis=3)\r\n#     return conv1\r\n\r\n\r\ndef u_net_bn(x, is_train=False, reuse=False, batch_size=None, pad='SAME', n_out=1):\r\n    \"\"\"image to image translation via conditional adversarial learning\"\"\"\r\n    nx = int(x._shape[1])\r\n    ny = int(x._shape[2])\r\n    nz = int(x._shape[3])\r\n    print(\" * Input: size of image: %d %d %d\" % (nx, ny, nz))\r\n\r\n    w_init = tf.truncated_normal_initializer(stddev=0.01)\r\n    b_init = tf.constant_initializer(value=0.0)\r\n    gamma_init=tf.random_normal_initializer(1., 0.02)\r\n    with tf.variable_scope(\"u_net\", reuse=reuse):\r\n        tl.layers.set_name_reuse(reuse)\r\n        inputs = InputLayer(x, name='inputs')\r\n\r\n        conv1 = Conv2d(inputs, 64, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv1')\r\n        conv2 = Conv2d(conv1, 128, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv2')\r\n        conv2 = BatchNormLayer(conv2, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn2')\r\n\r\n        conv3 = Conv2d(conv2, 256, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv3')\r\n        conv3 = BatchNormLayer(conv3, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn3')\r\n\r\n        conv4 = Conv2d(conv3, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv4')\r\n        conv4 = BatchNormLayer(conv4, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn4')\r\n\r\n        conv5 = Conv2d(conv4, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv5')\r\n        conv5 = BatchNormLayer(conv5, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn5')\r\n\r\n        conv6 = Conv2d(conv5, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv6')\r\n        conv6 = BatchNormLayer(conv6, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn6')\r\n\r\n        conv7 = Conv2d(conv6, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv7')\r\n        conv7 = BatchNormLayer(conv7, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn7')\r\n\r\n        conv8 = Conv2d(conv7, 512, (4, 4), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2), padding=pad, W_init=w_init, b_init=b_init, name='conv8')\r\n   
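     ## note: after eight stride-2 convolutions a 240x240 input is reduced to 1x1, so the hard-coded out_size values in the decoder below assume 240x240 inputs\r\n   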
     print(\" * After conv: %s\" % conv8.outputs)\r\n        # exit()\r\n        # print(nx/8)\r\n        up7 = DeConv2d(conv8, 512, (4, 4), out_size=(2, 2), strides=(2, 2),\r\n                                    padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv7')\r\n        up7 = BatchNormLayer(up7, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn7')\r\n\r\n        # print(up6.outputs)\r\n        up6 = ConcatLayer([up7, conv7], concat_dim=3, name='concat6')\r\n        up6 = DeConv2d(up6, 1024, (4, 4), out_size=(4, 4), strides=(2, 2),\r\n                                    padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv6')\r\n        up6 = BatchNormLayer(up6, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn6')\r\n        # print(up6.outputs)\r\n        # exit()\r\n\r\n        up5 = ConcatLayer([up6, conv6], concat_dim=3, name='concat5')\r\n        up5 = DeConv2d(up5, 1024, (4, 4), out_size=(8, 8), strides=(2, 2),\r\n                                    padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv5')\r\n        up5 = BatchNormLayer(up5, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn5')\r\n        # print(up5.outputs)\r\n        # exit()\r\n\r\n        up4 = ConcatLayer([up5, conv5] ,concat_dim=3, name='concat4')\r\n        up4 = DeConv2d(up4, 1024, (4, 4), out_size=(15, 15), strides=(2, 2),\r\n                                    padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv4')\r\n        up4 = BatchNormLayer(up4, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn4')\r\n\r\n        up3 = ConcatLayer([up4, conv4] ,concat_dim=3, name='concat3')\r\n        up3 = DeConv2d(up3, 256, (4, 4), out_size=(30, 30), strides=(2, 2),\r\n                                    padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv3')\r\n        up3 = BatchNormLayer(up3, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn3')\r\n\r\n        up2 = ConcatLayer([up3, conv3] ,concat_dim=3, name='concat2')\r\n        up2 = DeConv2d(up2, 128, (4, 4), out_size=(60, 60), strides=(2, 2),\r\n                                    padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv2')\r\n        up2 = BatchNormLayer(up2, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn2')\r\n\r\n        up1 = ConcatLayer([up2, conv2] ,concat_dim=3, name='concat1')\r\n        up1 = DeConv2d(up1, 64, (4, 4), out_size=(120, 120), strides=(2, 2),\r\n                                    padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv1')\r\n        up1 = BatchNormLayer(up1, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn1')\r\n\r\n        up0 = ConcatLayer([up1, conv1] ,concat_dim=3, name='concat0')\r\n        up0 = DeConv2d(up0, 64, (4, 4), out_size=(240, 240), strides=(2, 2),\r\n                                    padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv0')\r\n        up0 = BatchNormLayer(up0, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn0')\r\n        # print(up0.outputs)\r\n        # exit()\r\n\r\n        out = Conv2d(up0, n_out, (1, 1), act=tf.nn.sigmoid, name='out')\r\n\r\n        print(\" * Output: %s\" % out.outputs)\r\n        # exit()\r\n\r\n    return out\r\n\r\n## 
old implementation\r\n# def u_net_2d_64_1024_deconv(x, n_out=2):\r\n#     from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer\r\n#     nx = int(x._shape[1])\r\n#     ny = int(x._shape[2])\r\n#     nz = int(x._shape[3])\r\n#     print(\" * Input: size of image: %d %d %d\" % (nx, ny, nz))\r\n#\r\n#     w_init = tf.truncated_normal_initializer(stddev=0.01)\r\n#     b_init = tf.constant_initializer(value=0.0)\r\n#     inputs = InputLayer(x, name='inputs')\r\n#\r\n#     conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')\r\n#     conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')\r\n#     pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')\r\n#\r\n#     conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')\r\n#     conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')\r\n#     pool2 = MaxPool2d(conv2, (2, 2), padding='SAME', name='pool2')\r\n#\r\n#     conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')\r\n#     conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')\r\n#     pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')\r\n#\r\n#     conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')\r\n#     conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')\r\n#     pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')\r\n#\r\n#     conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')\r\n#     conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')\r\n#\r\n#     print(\" * After conv: %s\" % conv5.outputs)\r\n#\r\n#     up4 = DeConv2d(conv5, 512, (3, 3), out_size = (nx/8, ny/8), strides = (2, 2),\r\n#                                 padding = 'SAME', act=None, W_init=w_init, b_init=b_init, name='deconv4')\r\n#     up4 = ConcatLayer([up4, conv4], concat_dim=3, name='concat4')\r\n#     conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv4_1')\r\n#     conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv4_2')\r\n#\r\n#     up3 = DeConv2d(conv4, 256, (3, 3), out_size = (nx/4, ny/4), strides = (2, 2),\r\n#                                 padding = 'SAME', act=None, W_init=w_init, b_init=b_init, name='deconv3')\r\n#     up3 = ConcatLayer([up3, conv3], concat_dim=3, name='concat3')\r\n#     conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv3_1')\r\n#     conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv3_2')\r\n#\r\n#     up2 = DeConv2d(conv3, 128, (3, 3), out_size = (nx/2, ny/2), strides = (2, 2),\r\n#                                 padding = 'SAME', act=None, W_init=w_init, b_init=b_init, name='deconv2')\r\n#     up2 = ConcatLayer([up2, conv2] ,concat_dim=3, name='concat2')\r\n#     conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv2_1')\r\n#     conv2 = Conv2d(conv2, 128, (3, 
3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv2_2')\r\n#\r\n#     up1 = DeConv2d(conv2, 64, (3, 3), out_size = (nx/1, ny/1), strides = (2, 2),\r\n#                                 padding = 'SAME', act=None, W_init=w_init, b_init=b_init, name='deconv1')\r\n#     up1 = ConcatLayer([up1, conv1] ,concat_dim=3, name='concat1')\r\n#     conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv1_1')\r\n#     conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv1_2')\r\n#\r\n#     conv1 = Conv2d(conv1, n_out, (1, 1), act=None, name='uconv1')\r\n#     print(\" * Output: %s\" % conv1.outputs)\r\n#     outputs = tl.act.pixel_wise_softmax(conv1.outputs)\r\n#     return conv1, outputs\r\n#\r\n#\r\n# def u_net_2d_32_1024_upsam(x, n_out=2):\r\n#     \"\"\"\r\n#     https://github.com/jocicmarko/ultrasound-nerve-segmentation\r\n#     \"\"\"\r\n#     from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer\r\n#     batch_size = int(x._shape[0])\r\n#     nx = int(x._shape[1])\r\n#     ny = int(x._shape[2])\r\n#     nz = int(x._shape[3])\r\n#     print(\" * Input: size of image: %d %d %d\" % (nx, ny, nz))\r\n#     ## define initializer\r\n#     w_init = tf.truncated_normal_initializer(stddev=0.01)\r\n#     b_init = tf.constant_initializer(value=0.0)\r\n#     inputs = InputLayer(x, name='inputs')\r\n#\r\n#     conv1 = Conv2d(inputs, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')\r\n#     conv1 = Conv2d(conv1, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')\r\n#     pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')\r\n#\r\n#     conv2 = Conv2d(pool1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')\r\n#     conv2 = Conv2d(conv2, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')\r\n#     pool2 = MaxPool2d(conv2, (2,2), padding='SAME', name='pool2')\r\n#\r\n#     conv3 = Conv2d(pool2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')\r\n#     conv3 = Conv2d(conv3, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')\r\n#     pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')\r\n#\r\n#     conv4 = Conv2d(pool3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')\r\n#     conv4 = Conv2d(conv4, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')\r\n#     pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')\r\n#\r\n#     conv5 = Conv2d(pool4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')\r\n#     conv5 = Conv2d(conv5, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')\r\n#     pool5 = MaxPool2d(conv5, (2, 2), padding='SAME', name='pool6')\r\n#\r\n#     # hao add\r\n#     conv6 = Conv2d(pool5, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_1')\r\n#     conv6 = Conv2d(conv6, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_2')\r\n#\r\n#     print(\" * After conv: %s\" % conv6.outputs)\r\n#\r\n#     # hao add\r\n#     up7 = UpSampling2dLayer(conv6, (15, 15), is_scale=False, method=1, name='up7')\r\n#     up7 =  ConcatLayer([up7, conv5], 
concat_dim=3, name='concat7')\r\n#     conv7 = Conv2d(up7, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_1')\r\n#     conv7 = Conv2d(conv7, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_2')\r\n#\r\n#     # print(nx/8,ny/8) # 30 30\r\n#     up8 = UpSampling2dLayer(conv7, (2, 2), method=1, name='up8')\r\n#     up8 = ConcatLayer([up8, conv4], concat_dim=3, name='concat8')\r\n#     conv8 = Conv2d(up8, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_1')\r\n#     conv8 = Conv2d(conv8, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_2')\r\n#\r\n#     up9 = UpSampling2dLayer(conv8, (2, 2), method=1, name='up9')\r\n#     up9 = ConcatLayer([up9, conv3] ,concat_dim=3, name='concat9')\r\n#     conv9 = Conv2d(up9, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_1')\r\n#     conv9 = Conv2d(conv9, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_2')\r\n#\r\n#     up10 = UpSampling2dLayer(conv9, (2, 2), method=1, name='up10')\r\n#     up10 = ConcatLayer([up10, conv2] ,concat_dim=3, name='concat10')\r\n#     conv10 = Conv2d(up10, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv10_1')\r\n#     conv10 = Conv2d(conv10, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv10_2')\r\n#\r\n#     up11 = UpSampling2dLayer(conv10, (2, 2), method=1, name='up11')\r\n#     up11 = ConcatLayer([up11, conv1] ,concat_dim=3, name='concat11')\r\n#     conv11 = Conv2d(up11, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv11_1')\r\n#     conv11 = Conv2d(conv11, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv11_2')\r\n#\r\n#     conv12 = Conv2d(conv11, n_out, (1, 1), act=None, name='conv12')\r\n#     print(\" * Output: %s\" % conv12.outputs)\r\n#     outputs = tl.act.pixel_wise_softmax(conv12.outputs)\r\n#     return conv10, outputs\r\n#\r\n#\r\n# def u_net_2d_32_512_upsam(x, n_out=2):\r\n#     \"\"\"\r\n#     https://github.com/jocicmarko/ultrasound-nerve-segmentation\r\n#     \"\"\"\r\n#     from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer\r\n#     batch_size = int(x._shape[0])\r\n#     nx = int(x._shape[1])\r\n#     ny = int(x._shape[2])\r\n#     nz = int(x._shape[3])\r\n#     print(\" * Input: size of image: %d %d %d\" % (nx, ny, nz))\r\n#     ## define initializer\r\n#     w_init = tf.truncated_normal_initializer(stddev=0.01)\r\n#     b_init = tf.constant_initializer(value=0.0)\r\n#     inputs = InputLayer(x, name='inputs')\r\n#     # inputs = Input((1, img_rows, img_cols))\r\n#     conv1 = Conv2d(inputs, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')\r\n#     # print(conv1.outputs) # (10, 240, 240, 32)\r\n#     # conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)\r\n#     conv1 = Conv2d(conv1, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')\r\n#     # print(conv1.outputs)    # (10, 240, 240, 32)\r\n#     # conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1)\r\n#     pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')\r\n#     # pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\r\n#     # print(pool1.outputs)    # (10, 120, 120, 32)\r\n#     # exit()\r\n#     conv2 = 
Conv2d(pool1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')\r\n#     # conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1)\r\n#     conv2 = Conv2d(conv2, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')\r\n#     # conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2)\r\n#     pool2 = MaxPool2d(conv2, (2,2), padding='SAME', name='pool2')\r\n#     # pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\r\n#\r\n#     conv3 = Conv2d(pool2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')\r\n#     # conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2)\r\n#     conv3 = Conv2d(conv3, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')\r\n#     # conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3)\r\n#     pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')\r\n#     # pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\r\n#     # print(pool3.outputs)   # (10, 30, 30, 64)\r\n#\r\n#     conv4 = Conv2d(pool3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')\r\n#     # print(conv4.outputs)    # (10, 30, 30, 256)\r\n#     # conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3)\r\n#     conv4 = Conv2d(conv4, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')\r\n#     # print(conv4.outputs)    # (10, 30, 30, 256) != (10, 30, 30, 512)\r\n#     # conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4)\r\n#     pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')\r\n#     # pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)\r\n#\r\n#     conv5 = Conv2d(pool4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')\r\n#     # conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4)\r\n#     conv5 = Conv2d(conv5, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')\r\n#     # conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5)\r\n#     # print(conv5.outputs)    # (10, 15, 15, 512)\r\n#     print(\" * After conv: %s\" % conv5.outputs)\r\n#     # print(nx/8,ny/8) # 30 30\r\n#     up6 = UpSampling2dLayer(conv5, (2, 2), name='up6')\r\n#     # print(up6.outputs)  # (10, 30, 30, 512) == (10, 30, 30, 512)\r\n#     up6 = ConcatLayer([up6, conv4], concat_dim=3, name='concat6')\r\n#     # print(up6.outputs)  # (10, 30, 30, 768)\r\n#     # up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1)\r\n#     conv6 = Conv2d(up6, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_1')\r\n#     # conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6)\r\n#     conv6 = Conv2d(conv6, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_2')\r\n#     # conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6)\r\n#\r\n#     up7 = UpSampling2dLayer(conv6, (2, 2), name='up7')\r\n#     up7 = ConcatLayer([up7, conv3] ,concat_dim=3, name='concat7')\r\n#     # up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1)\r\n#     conv7 = Conv2d(up7, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_1')\r\n#     # conv7 = 
Convolution2D(128, 3, 3, activation='relu', border_mode='same')(up7)\r\n#     conv7 = Conv2d(conv7, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_2')\r\n#     # conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7)\r\n#\r\n#     up8 = UpSampling2dLayer(conv7, (2, 2), name='up8')\r\n#     up8 = ConcatLayer([up8, conv2] ,concat_dim=3, name='concat8')\r\n#     # up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1)\r\n#     conv8 = Conv2d(up8, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_1')\r\n#     # conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8)\r\n#     conv8 = Conv2d(conv8, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_2')\r\n#     # conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8)\r\n#\r\n#     up9 = UpSampling2dLayer(conv8, (2, 2), name='up9')\r\n#     up9 = ConcatLayer([up9, conv1] ,concat_dim=3, name='concat9')\r\n#     # up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1)\r\n#     conv9 = Conv2d(up9, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_1')\r\n#     # conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9)\r\n#     conv9 = Conv2d(conv9, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_2')\r\n#     # conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9)\r\n#\r\n#     conv10 = Conv2d(conv9, n_out, (1, 1), act=None, name='conv9')\r\n#     # conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9)\r\n#     print(\" * Output: %s\" % conv10.outputs)\r\n#     outputs = tl.act.pixel_wise_softmax(conv10.outputs)\r\n#     return conv10, outputs\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    pass\r\n    # main()\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n#\r\n"
  },
  {
    "path": "prepare_data_with_valid.py",
    "content": "import tensorlayer as tl\r\nimport numpy as np\r\nimport os, csv, random, gc, pickle\r\nimport nibabel as nib\r\n\r\n\r\n\"\"\"\r\nIn seg file\r\n--------------\r\nLabel 1: necrotic and non-enhancing tumor\r\nLabel 2: edema \r\nLabel 4: enhancing tumor\r\nLabel 0: background\r\n\r\nMRI\r\n-------\r\nwhole/complete tumor: 1 2 4\r\ncore: 1 4\r\nenhance: 4\r\n\"\"\"\r\n###============================= SETTINGS ===================================###\r\nDATA_SIZE = 'half' # (small, half or all)\r\n\r\nsave_dir = \"data/train_dev_all/\"\r\nif not os.path.exists(save_dir):\r\n    os.makedirs(save_dir)\r\n\r\nHGG_data_path = \"data/Brats17TrainingData/HGG\"\r\nLGG_data_path = \"data/Brats17TrainingData/LGG\"\r\nsurvival_csv_path = \"data/Brats17TrainingData/survival_data.csv\"\r\n###==========================================================================###\r\n\r\nsurvival_id_list = []\r\nsurvival_age_list =[]\r\nsurvival_peroid_list = []\r\n\r\nwith open(survival_csv_path, 'r') as f:\r\n    reader = csv.reader(f)\r\n    next(reader)\r\n    for idx, content in enumerate(reader):\r\n        survival_id_list.append(content[0])\r\n        survival_age_list.append(float(content[1]))\r\n        survival_peroid_list.append(float(content[2]))\r\n\r\nprint(len(survival_id_list)) #163\r\n\r\nif DATA_SIZE == 'all':\r\n    HGG_path_list = tl.files.load_folder_list(path=HGG_data_path)\r\n    LGG_path_list = tl.files.load_folder_list(path=LGG_data_path)\r\nelif DATA_SIZE == 'half':\r\n    HGG_path_list = tl.files.load_folder_list(path=HGG_data_path)[0:100]# DEBUG WITH SMALL DATA\r\n    LGG_path_list = tl.files.load_folder_list(path=LGG_data_path)[0:30] # DEBUG WITH SMALL DATA\r\nelif DATA_SIZE == 'small':\r\n    HGG_path_list = tl.files.load_folder_list(path=HGG_data_path)[0:50] # DEBUG WITH SMALL DATA\r\n    LGG_path_list = tl.files.load_folder_list(path=LGG_data_path)[0:20] # DEBUG WITH SMALL DATA\r\nelse:\r\n    exit(\"Unknow DATA_SIZE\")\r\nprint(len(HGG_path_list), len(LGG_path_list)) #210 #75\r\n\r\nHGG_name_list = [os.path.basename(p) for p in HGG_path_list]\r\nLGG_name_list = [os.path.basename(p) for p in LGG_path_list]\r\n\r\nsurvival_id_from_HGG = []\r\nsurvival_id_from_LGG = []\r\nfor i in survival_id_list:\r\n    if i in HGG_name_list:\r\n        survival_id_from_HGG.append(i)\r\n    elif i in LGG_name_list:\r\n        survival_id_from_LGG.append(i)\r\n    else:\r\n        print(i)\r\n\r\nprint(len(survival_id_from_HGG), len(survival_id_from_LGG)) #163, 0\r\n\r\n# use 42 from 210 (in 163 subset) and 15 from 75 as 0.8/0.2 train/dev split\r\n\r\n# use 126/42/42 from 210 (in 163 subset) and 45/15/15 from 75 as 0.6/0.2/0.2 train/dev/test split\r\nindex_HGG = list(range(0, len(survival_id_from_HGG)))\r\nindex_LGG = list(range(0, len(LGG_name_list)))\r\n# random.shuffle(index_HGG)\r\n# random.shuffle(index_HGG)\r\n\r\nif DATA_SIZE == 'all':\r\n    dev_index_HGG = index_HGG[-84:-42]\r\n    test_index_HGG = index_HGG[-42:]\r\n    tr_index_HGG = index_HGG[:-84]\r\n    dev_index_LGG = index_LGG[-30:-15]\r\n    test_index_LGG = index_LGG[-15:]\r\n    tr_index_LGG = index_LGG[:-30]\r\nelif DATA_SIZE == 'half':\r\n    dev_index_HGG = index_HGG[-30:]  # DEBUG WITH SMALL DATA\r\n    test_index_HGG = index_HGG[-5:]\r\n    tr_index_HGG = index_HGG[:-30]\r\n    dev_index_LGG = index_LGG[-10:]  # DEBUG WITH SMALL DATA\r\n    test_index_LGG = index_LGG[-5:]\r\n    tr_index_LGG = index_LGG[:-10]\r\nelif DATA_SIZE == 'small':\r\n    dev_index_HGG = index_HGG[35:42]   # DEBUG WITH SMALL DATA\r\n    # 
print(index_HGG, dev_index_HGG)\r\n    # exit()\r\n    test_index_HGG = index_HGG[41:42]\r\n    tr_index_HGG = index_HGG[0:35]\r\n    dev_index_LGG = index_LGG[7:10]    # DEBUG WITH SMALL DATA\r\n    test_index_LGG = index_LGG[9:10]\r\n    tr_index_LGG = index_LGG[0:7]\r\n\r\nsurvival_id_dev_HGG = [survival_id_from_HGG[i] for i in dev_index_HGG]\r\nsurvival_id_test_HGG = [survival_id_from_HGG[i] for i in test_index_HGG]\r\nsurvival_id_tr_HGG = [survival_id_from_HGG[i] for i in tr_index_HGG]\r\n\r\nsurvival_id_dev_LGG = [LGG_name_list[i] for i in dev_index_LGG]\r\nsurvival_id_test_LGG = [LGG_name_list[i] for i in test_index_LGG]\r\nsurvival_id_tr_LGG = [LGG_name_list[i] for i in tr_index_LGG]\r\n\r\nsurvival_age_dev = [survival_age_list[survival_id_list.index(i)] for i in survival_id_dev_HGG]\r\nsurvival_age_test = [survival_age_list[survival_id_list.index(i)] for i in survival_id_test_HGG]\r\nsurvival_age_tr = [survival_age_list[survival_id_list.index(i)] for i in survival_id_tr_HGG]\r\n\r\nsurvival_period_dev = [survival_period_list[survival_id_list.index(i)] for i in survival_id_dev_HGG]\r\nsurvival_period_test = [survival_period_list[survival_id_list.index(i)] for i in survival_id_test_HGG]\r\nsurvival_period_tr = [survival_period_list[survival_id_list.index(i)] for i in survival_id_tr_HGG]\r\n\r\ndata_types = ['flair', 't1', 't1ce', 't2']\r\ndata_types_mean_std_dict = {i: {'mean': 0.0, 'std': 1.0} for i in data_types}\r\n\r\n# calculate mean and std for all data types\r\n\r\n# preserving_ratio = 0.0\r\n# preserving_ratio = 0.01 # 0.118 removed\r\n# preserving_ratio = 0.05 # 0.213 removed\r\n# preserving_ratio = 0.10 # 0.359 removed\r\n\r\n#==================== LOAD ALL IMAGES' PATH AND COMPUTE MEAN/STD\r\nfor i in data_types:\r\n    data_temp_list = []\r\n    for j in HGG_name_list:\r\n        img_path = os.path.join(HGG_data_path, j, j + '_' + i + '.nii.gz')\r\n        img = nib.load(img_path).get_data()\r\n        data_temp_list.append(img)\r\n\r\n    for j in LGG_name_list:\r\n        img_path = os.path.join(LGG_data_path, j, j + '_' + i + '.nii.gz')\r\n        img = nib.load(img_path).get_data()\r\n        data_temp_list.append(img)\r\n\r\n    data_temp_list = np.asarray(data_temp_list)\r\n    m = np.mean(data_temp_list)\r\n    s = np.std(data_temp_list)\r\n    data_types_mean_std_dict[i]['mean'] = m\r\n    data_types_mean_std_dict[i]['std'] = s\r\ndel data_temp_list\r\nprint(data_types_mean_std_dict)\r\n\r\nwith open(save_dir + 'mean_std_dict.pickle', 'wb') as f:\r\n    pickle.dump(data_types_mean_std_dict, f, protocol=4)\r\n\r\n\r\n##==================== GET NORMALIZED IMAGES\r\nX_train_input = []\r\nX_train_target = []\r\n# X_train_target_whole = [] # 1 2 4\r\n# X_train_target_core = [] # 1 4\r\n# X_train_target_enhance = [] # 4\r\n\r\nX_dev_input = []\r\nX_dev_target = []\r\n# X_dev_target_whole = [] # 1 2 4\r\n# X_dev_target_core = [] # 1 4\r\n# X_dev_target_enhance = [] # 4\r\n\r\nprint(\" HGG Validation\")\r\nfor i in survival_id_dev_HGG:\r\n    all_3d_data = []\r\n    for j in data_types:\r\n        img_path = os.path.join(HGG_data_path, i, i + '_' + j + '.nii.gz')\r\n        img = nib.load(img_path).get_data()\r\n        img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std']\r\n        img = img.astype(np.float32)\r\n        all_3d_data.append(img)\r\n\r\n    seg_path = os.path.join(HGG_data_path, i, i + '_seg.nii.gz')\r\n    seg_img = nib.load(seg_path).get_data()\r\n    seg_img = np.transpose(seg_img, (1, 0, 2))\r\n    for j in 
range(all_3d_data[0].shape[2]):\r\n        combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)\r\n        combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()\r\n        combined_array = combined_array.astype(np.float32)  # astype returns a new array, so keep the result\r\n        X_dev_input.append(combined_array)\r\n\r\n        seg_2d = seg_img[:, :, j]\r\n        # whole = np.zeros_like(seg_2d)\r\n        # core = np.zeros_like(seg_2d)\r\n        # enhance = np.zeros_like(seg_2d)\r\n        # for index, x in np.ndenumerate(seg_2d):\r\n        #     if x == 1:\r\n        #         whole[index] = 1\r\n        #         core[index] = 1\r\n        #     if x == 2:\r\n        #         whole[index] = 1\r\n        #     if x == 4:\r\n        #         whole[index] = 1\r\n        #         core[index] = 1\r\n        #         enhance[index] = 1\r\n        # X_dev_target_whole.append(whole)\r\n        # X_dev_target_core.append(core)\r\n        # X_dev_target_enhance.append(enhance)\r\n        seg_2d = seg_2d.astype(int)\r\n        X_dev_target.append(seg_2d)\r\n    del all_3d_data\r\n    gc.collect()\r\n    print(\"finished {}\".format(i))\r\n\r\nprint(\" LGG Validation\")\r\nfor i in survival_id_dev_LGG:\r\n    all_3d_data = []\r\n    for j in data_types:\r\n        img_path = os.path.join(LGG_data_path, i, i + '_' + j + '.nii.gz')\r\n        img = nib.load(img_path).get_data()\r\n        img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std']\r\n        img = img.astype(np.float32)\r\n        all_3d_data.append(img)\r\n\r\n    seg_path = os.path.join(LGG_data_path, i, i + '_seg.nii.gz')\r\n    seg_img = nib.load(seg_path).get_data()\r\n    seg_img = np.transpose(seg_img, (1, 0, 2))\r\n    for j in range(all_3d_data[0].shape[2]):\r\n        combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)\r\n        combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()\r\n        combined_array = combined_array.astype(np.float32)\r\n        X_dev_input.append(combined_array)\r\n\r\n        seg_2d = seg_img[:, :, j]\r\n        # whole = np.zeros_like(seg_2d)\r\n        # core = np.zeros_like(seg_2d)\r\n        # enhance = np.zeros_like(seg_2d)\r\n        # for index, x in np.ndenumerate(seg_2d):\r\n        #     if x == 1:\r\n        #         whole[index] = 1\r\n        #         core[index] = 1\r\n        #     if x == 2:\r\n        #         whole[index] = 1\r\n        #     if x == 4:\r\n        #         whole[index] = 1\r\n        #         core[index] = 1\r\n        #         enhance[index] = 1\r\n        # X_dev_target_whole.append(whole)\r\n        # X_dev_target_core.append(core)\r\n        # X_dev_target_enhance.append(enhance)\r\n        seg_2d = seg_2d.astype(int)\r\n        X_dev_target.append(seg_2d)\r\n    del all_3d_data\r\n    gc.collect()\r\n    print(\"finished {}\".format(i))\r\n\r\nX_dev_input = np.asarray(X_dev_input, dtype=np.float32)\r\nX_dev_target = np.asarray(X_dev_target)#, dtype=np.float32)\r\n# print(X_dev_input.shape)\r\n# print(X_dev_target.shape)\r\n\r\n# with open(save_dir + 'dev_input.pickle', 'wb') as f:\r\n#     pickle.dump(X_dev_input, f, protocol=4)\r\n# with open(save_dir + 'dev_target.pickle', 'wb') as f:\r\n#     pickle.dump(X_dev_target, f, protocol=4)\r\n\r\n# del X_dev_input, X_dev_target\r\n\r\nprint(\" HGG Train\")\r\nfor i in survival_id_tr_HGG:\r\n    all_3d_data = []\r\n    for j in data_types:\r\n        img_path = 
os.path.join(HGG_data_path, i, i + '_' + j + '.nii.gz')\r\n        img = nib.load(img_path).get_data()\r\n        img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std']\r\n        img = img.astype(np.float32)\r\n        all_3d_data.append(img)\r\n\r\n    seg_path = os.path.join(HGG_data_path, i, i + '_seg.nii.gz')\r\n    seg_img = nib.load(seg_path).get_data()\r\n    seg_img = np.transpose(seg_img, (1, 0, 2))\r\n    for j in range(all_3d_data[0].shape[2]):\r\n        combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)\r\n        combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()\r\n        combined_array = combined_array.astype(np.float32)\r\n        X_train_input.append(combined_array)\r\n\r\n        seg_2d = seg_img[:, :, j]\r\n        # whole = np.zeros_like(seg_2d)\r\n        # core = np.zeros_like(seg_2d)\r\n        # enhance = np.zeros_like(seg_2d)\r\n        # for index, x in np.ndenumerate(seg_2d):\r\n        #     if x == 1:\r\n        #         whole[index] = 1\r\n        #         core[index] = 1\r\n        #     if x == 2:\r\n        #         whole[index] = 1\r\n        #     if x == 4:\r\n        #         whole[index] = 1\r\n        #         core[index] = 1\r\n        #         enhance[index] = 1\r\n        # X_train_target_whole.append(whole)\r\n        # X_train_target_core.append(core)\r\n        # X_train_target_enhance.append(enhance)\r\n        seg_2d = seg_2d.astype(int)\r\n        X_train_target.append(seg_2d)\r\n    del all_3d_data\r\n    print(\"finished {}\".format(i))\r\n    # print(len(X_train_target))\r\n\r\n\r\nprint(\" LGG Train\")\r\nfor i in survival_id_tr_LGG:\r\n    all_3d_data = []\r\n    for j in data_types:\r\n        img_path = os.path.join(LGG_data_path, i, i + '_' + j + '.nii.gz')\r\n        img = nib.load(img_path).get_data()\r\n        img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std']\r\n        img = img.astype(np.float32)\r\n        all_3d_data.append(img)\r\n\r\n    seg_path = os.path.join(LGG_data_path, i, i + '_seg.nii.gz')\r\n    seg_img = nib.load(seg_path).get_data()\r\n    seg_img = np.transpose(seg_img, (1, 0, 2))\r\n    for j in range(all_3d_data[0].shape[2]):\r\n        combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)\r\n        combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()\r\n        combined_array = combined_array.astype(np.float32)\r\n        X_train_input.append(combined_array)\r\n\r\n        seg_2d = seg_img[:, :, j]\r\n        # whole = np.zeros_like(seg_2d)\r\n        # core = np.zeros_like(seg_2d)\r\n        # enhance = np.zeros_like(seg_2d)\r\n        # for index, x in np.ndenumerate(seg_2d):\r\n        #     if x == 1:\r\n        #         whole[index] = 1\r\n        #         core[index] = 1\r\n        #     if x == 2:\r\n        #         whole[index] = 1\r\n        #     if x == 4:\r\n        #         whole[index] = 1\r\n        #         core[index] = 1\r\n        #         enhance[index] = 1\r\n        # X_train_target_whole.append(whole)\r\n        # X_train_target_core.append(core)\r\n        # X_train_target_enhance.append(enhance)\r\n        seg_2d = seg_2d.astype(int)\r\n        X_train_target.append(seg_2d)\r\n    del all_3d_data\r\n    print(\"finished {}\".format(i))\r\n\r\nX_train_input = np.asarray(X_train_input, dtype=np.float32)\r\nX_train_target = np.asarray(X_train_target)#, 
dtype=np.float32)\r\n# print(X_train_input.shape)\r\n# print(X_train_target.shape)\r\n\r\n# with open(save_dir + 'train_input.pickle', 'wb') as f:\r\n#     pickle.dump(X_train_input, f, protocol=4)\r\n# with open(save_dir + 'train_target.pickle', 'wb') as f:\r\n#     pickle.dump(X_train_target, f, protocol=4)\r\n\r\n\r\n\r\n# X_train_target_whole = np.asarray(X_train_target_whole)\r\n# X_train_target_core = np.asarray(X_train_target_core)\r\n# X_train_target_enhance = np.asarray(X_train_target_enhance)\r\n\r\n\r\n# X_dev_target_whole = np.asarray(X_dev_target_whole)\r\n# X_dev_target_core = np.asarray(X_dev_target_core)\r\n# X_dev_target_enhance = np.asarray(X_dev_target_enhance)\r\n\r\n\r\n# print(X_train_target_whole.shape)\r\n# print(X_train_target_core.shape)\r\n# print(X_train_target_enhance.shape)\r\n\r\n# print(X_dev_target_whole.shape)\r\n# print(X_dev_target_core.shape)\r\n# print(X_dev_target_enhance.shape)\r\n\r\n\r\n\r\n# with open(save_dir + 'train_target_whole.pickle', 'wb') as f:\r\n#     pickle.dump(X_train_target_whole, f, protocol=4)\r\n\r\n# with open(save_dir + 'train_target_core.pickle', 'wb') as f:\r\n#     pickle.dump(X_train_target_core, f, protocol=4)\r\n\r\n# with open(save_dir + 'train_target_enhance.pickle', 'wb') as f:\r\n#     pickle.dump(X_train_target_enhance, f, protocol=4)\r\n\r\n# with open(save_dir + 'dev_target_whole.pickle', 'wb') as f:\r\n#     pickle.dump(X_dev_target_whole, f, protocol=4)\r\n\r\n# with open(save_dir + 'dev_target_core.pickle', 'wb') as f:\r\n#     pickle.dump(X_dev_target_core, f, protocol=4)\r\n\r\n# with open(save_dir + 'dev_target_enhance.pickle', 'wb') as f:\r\n#     pickle.dump(X_dev_target_enhance, f, protocol=4)\r\n"
  },
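  {
    "path": "example_predict.py",
    "content": "\"\"\"Hypothetical inference sketch, not part of the original repo. It assumes a\r\ncheckpoint produced by `python train.py --task=all` exists at\r\ncheckpoint/u_net_all.npz, and that X is one normalized (240, 240, 4) slice\r\nprepared the same way as in prepare_data_with_valid.py.\r\n\"\"\"\r\nimport numpy as np\r\nimport tensorflow as tf\r\nimport tensorlayer as tl\r\nimport model\r\n\r\n\r\ndef predict_slice(X, task='all', checkpoint_dir='checkpoint'):\r\n    \"\"\"Return the (240, 240, 1) sigmoid probability map for one slice.\"\"\"\r\n    nw, nh, nz = X.shape\r\n    t_image = tf.placeholder('float32', [1, nw, nh, nz], name='input_image')\r\n    net = model.u_net(t_image, is_train=False, reuse=False, n_out=1)\r\n    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))\r\n    tl.layers.initialize_global_variables(sess)\r\n    # restore the weights saved by train.py (returns False if the file is missing)\r\n    tl.files.load_and_assign_npz(sess=sess,\r\n            name=checkpoint_dir + '/u_net_{}.npz'.format(task), network=net)\r\n    out = sess.run(net.outputs, {t_image: X[np.newaxis, :, :, :]})\r\n    sess.close()\r\n    return out[0]\r\n"
  },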
  {
    "path": "train.py",
    "content": "#! /usr/bin/python\r\n# -*- coding: utf8 -*-\r\n\r\nimport tensorflow as tf\r\nimport tensorlayer as tl\r\nimport numpy as np\r\nimport os, time, model\r\n\r\ndef distort_imgs(data):\n    \"\"\" data augumentation \"\"\"\n    x1, x2, x3, x4, y = data\n    # x1, x2, x3, x4, y = tl.prepro.flip_axis_multi([x1, x2, x3, x4, y],  # previous without this, hard-dice=83.7\n    #                         axis=0, is_random=True) # up down\n    x1, x2, x3, x4, y = tl.prepro.flip_axis_multi([x1, x2, x3, x4, y],\n                            axis=1, is_random=True) # left right\n    x1, x2, x3, x4, y = tl.prepro.elastic_transform_multi([x1, x2, x3, x4, y],\n                            alpha=720, sigma=24, is_random=True)\n    x1, x2, x3, x4, y = tl.prepro.rotation_multi([x1, x2, x3, x4, y], rg=20,\n                            is_random=True, fill_mode='constant') # nearest, constant\n    x1, x2, x3, x4, y = tl.prepro.shift_multi([x1, x2, x3, x4, y], wrg=0.10,\n                            hrg=0.10, is_random=True, fill_mode='constant')\n    x1, x2, x3, x4, y = tl.prepro.shear_multi([x1, x2, x3, x4, y], 0.05,\n                            is_random=True, fill_mode='constant')\n    x1, x2, x3, x4, y = tl.prepro.zoom_multi([x1, x2, x3, x4, y],\n                            zoom_range=[0.9, 1.1], is_random=True,\n                            fill_mode='constant')\n    return x1, x2, x3, x4, y\n\ndef vis_imgs(X, y, path):\r\n    \"\"\" show one slice \"\"\"\r\n    if y.ndim == 2:\r\n        y = y[:,:,np.newaxis]\r\n    assert X.ndim == 3\r\n    tl.vis.save_images(np.asarray([X[:,:,0,np.newaxis],\r\n        X[:,:,1,np.newaxis], X[:,:,2,np.newaxis],\r\n        X[:,:,3,np.newaxis], y]), size=(1, 5),\r\n        image_path=path)\r\n\r\ndef vis_imgs2(X, y_, y, path):\r\n    \"\"\" show one slice with target \"\"\"\r\n    if y.ndim == 2:\r\n        y = y[:,:,np.newaxis]\r\n    if y_.ndim == 2:\r\n        y_ = y_[:,:,np.newaxis]\r\n    assert X.ndim == 3\r\n    tl.vis.save_images(np.asarray([X[:,:,0,np.newaxis],\r\n        X[:,:,1,np.newaxis], X[:,:,2,np.newaxis],\r\n        X[:,:,3,np.newaxis], y_, y]), size=(1, 6),\r\n        image_path=path)\r\n\r\ndef main(task='all'):\r\n    ## Create folder to save trained model and result images\r\n    save_dir = \"checkpoint\"\r\n    tl.files.exists_or_mkdir(save_dir)\r\n    tl.files.exists_or_mkdir(\"samples/{}\".format(task))\r\n\r\n    ###======================== LOAD DATA ===================================###\r\n    ## by importing this, you can load a training set and a validation set.\r\n    # you will get X_train_input, X_train_target, X_dev_input and X_dev_target\r\n    # there are 4 labels in targets:\r\n    # Label 0: background\r\n    # Label 1: necrotic and non-enhancing tumor\r\n    # Label 2: edema\r\n    # Label 4: enhancing tumor\r\n    import prepare_data_with_valid as dataset\r\n    X_train = dataset.X_train_input\r\n    y_train = dataset.X_train_target[:,:,:,np.newaxis]\r\n    X_test = dataset.X_dev_input\r\n    y_test = dataset.X_dev_target[:,:,:,np.newaxis]\r\n\r\n    if task == 'all':\r\n        y_train = (y_train > 0).astype(int)\r\n        y_test = (y_test > 0).astype(int)\r\n    elif task == 'necrotic':\n        y_train = (y_train == 1).astype(int)\n        y_test = (y_test == 1).astype(int)\n    elif task == 'edema':\n        y_train = (y_train == 2).astype(int)\n        y_test = (y_test == 2).astype(int)\n    elif task == 'enhance':\n        y_train = (y_train == 4).astype(int)\n        y_test = (y_test == 4).astype(int)\n    else:\n     
   exit(\"Unknow task %s\" % task)\n\n    ###======================== HYPER-PARAMETERS ============================###\n    batch_size = 10\n    lr = 0.0001 \n    # lr_decay = 0.5\n    # decay_every = 100\n    beta1 = 0.9\n    n_epoch = 100\n    print_freq_step = 100\n\n    ###======================== SHOW DATA ===================================###\n    # show one slice\n    X = np.asarray(X_train[80])\n    y = np.asarray(y_train[80])\n    # print(X.shape, X.min(), X.max()) # (240, 240, 4) -0.380588 2.62761\n    # print(y.shape, y.min(), y.max()) # (240, 240, 1) 0 1\n    nw, nh, nz = X.shape\n    vis_imgs(X, y, 'samples/{}/_train_im.png'.format(task))\n    # show data augumentation results\n    for i in range(10):\n        x_flair, x_t1, x_t1ce, x_t2, label = distort_imgs([X[:,:,0,np.newaxis], X[:,:,1,np.newaxis],\n                X[:,:,2,np.newaxis], X[:,:,3,np.newaxis], y])#[:,:,np.newaxis]])\n        # print(x_flair.shape, x_t1.shape, x_t1ce.shape, x_t2.shape, label.shape) # (240, 240, 1) (240, 240, 1) (240, 240, 1) (240, 240, 1) (240, 240, 1)\n        X_dis = np.concatenate((x_flair, x_t1, x_t1ce, x_t2), axis=2)\n        # print(X_dis.shape, X_dis.min(), X_dis.max()) # (240, 240, 4) -0.380588233471 2.62376139209\n        vis_imgs(X_dis, label, 'samples/{}/_train_im_aug{}.png'.format(task, i))\n\r\n    with tf.device('/cpu:0'):\r\n        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))\r\n        with tf.device('/gpu:0'): #<- remove it if you train on CPU or other GPU\r\n            ###======================== DEFIINE MODEL =======================###\r\n            ## nz is 4 as we input all Flair, T1, T1c and T2.\r\n            t_image = tf.placeholder('float32', [batch_size, nw, nh, nz], name='input_image')\r\n            ## labels are either 0 or 1\r\n            t_seg = tf.placeholder('float32', [batch_size, nw, nh, 1], name='target_segment')\r\n            ## train inference\r\n            net = model.u_net(t_image, is_train=True, reuse=False, n_out=1)\r\n            ## test inference\r\n            net_test = model.u_net(t_image, is_train=False, reuse=True, n_out=1)\r\n\r\n            ###======================== DEFINE LOSS =========================###\r\n            ## train losses\r\n            out_seg = net.outputs\r\n            dice_loss = 1 - tl.cost.dice_coe(out_seg, t_seg, axis=[0,1,2,3])#, 'jaccard', epsilon=1e-5)\n            iou_loss = tl.cost.iou_coe(out_seg, t_seg, axis=[0,1,2,3])\r\n            dice_hard = tl.cost.dice_hard_coe(out_seg, t_seg, axis=[0,1,2,3])\r\n            loss = dice_loss\r\n\r\n            ## test losses\r\n            test_out_seg = net_test.outputs\r\n            test_dice_loss = 1 - tl.cost.dice_coe(test_out_seg, t_seg, axis=[0,1,2,3])#, 'jaccard', epsilon=1e-5)\n            test_iou_loss = tl.cost.iou_coe(test_out_seg, t_seg, axis=[0,1,2,3])\r\n            test_dice_hard = tl.cost.dice_hard_coe(test_out_seg, t_seg, axis=[0,1,2,3])\r\n\r\n        ###======================== DEFINE TRAIN OPTS =======================###\n        t_vars = tl.layers.get_variables_with_name('u_net', True, True)\n        with tf.device('/gpu:0'):\n            with tf.variable_scope('learning_rate'):\n                lr_v = tf.Variable(lr, trainable=False)\n            train_op = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(loss, var_list=t_vars)\n\r\n        ###======================== LOAD MODEL ==============================###\n        tl.layers.initialize_global_variables(sess)\n        ## load existing model if possible\n        
tl.files.load_and_assign_npz(sess=sess, name=save_dir+'/u_net_{}.npz'.format(task), network=net)\n\r\n        ###======================== TRAINING ================================###\r\n    for epoch in range(0, n_epoch+1):\n        epoch_time = time.time()\n        ## update decay learning rate at the beginning of an epoch\n        # if epoch !=0 and (epoch % decay_every == 0):\n        #     new_lr_decay = lr_decay ** (epoch // decay_every)\n        #     sess.run(tf.assign(lr_v, lr * new_lr_decay))\n        #     log = \" ** new learning rate: %f\" % (lr * new_lr_decay)\n        #     print(log)\n        # elif epoch == 0:\n        #     sess.run(tf.assign(lr_v, lr))\n        #     log = \" ** init lr: %f  decay_every_epoch: %d, lr_decay: %f\" % (lr, decay_every, lr_decay)\n        #     print(log)\n\r\n        total_dice, total_iou, total_dice_hard, n_batch = 0, 0, 0, 0\n        for batch in tl.iterate.minibatches(inputs=X_train, targets=y_train,\n                                    batch_size=batch_size, shuffle=True):\n            images, labels = batch\n            step_time = time.time()\n            ## data augmentation for a batch of Flair, T1, T1c, T2 images\n            # and label maps synchronously.\n            data = tl.prepro.threading_data([_ for _ in zip(images[:,:,:,0, np.newaxis],\n                    images[:,:,:,1, np.newaxis], images[:,:,:,2, np.newaxis],\n                    images[:,:,:,3, np.newaxis], labels)],\n                    fn=distort_imgs) # (10, 5, 240, 240, 1)\n            b_images = data[:,0:4,:,:,:]  # (10, 4, 240, 240, 1)\n            b_labels = data[:,4,:,:,:]\n            b_images = b_images.transpose((0,2,3,1,4))\n            b_images.shape = (batch_size, nw, nh, nz)\n\r\n            ## update network\n            _, _dice, _iou, _diceh, out = sess.run([train_op,\n                    dice_loss, iou_loss, dice_hard, net.outputs],\n                    {t_image: b_images, t_seg: b_labels})\n            total_dice += _dice; total_iou += _iou; total_dice_hard += _diceh\n            n_batch += 1\n\r\n            ## you can show the prediction here:\n            # vis_imgs2(b_images[0], b_labels[0], out[0], \"samples/{}/_tmp.png\".format(task))\n            # exit()\n\r\n            # if _dice == 1: # DEBUG\n            #     print(\"DEBUG\")\n            #     vis_imgs2(b_images[0], b_labels[0], out[0], \"samples/{}/_debug.png\".format(task))\n\r\n            if n_batch % print_freq_step == 0:\n                print(\"Epoch %d step %d 1-dice: %f hard-dice: %f iou: %f took %fs (2d with distortion)\"\n                % (epoch, n_batch, _dice, _diceh, _iou, time.time()-step_time))\n\r\n            ## check model fail\n            if np.isnan(_dice):\n                exit(\" ** NaN loss found during training, stop training\")\n            if np.isnan(out).any():\n                exit(\" ** NaN found in output images during training, stop training\")\n\r\n        print(\" ** Epoch [%d/%d] train 1-dice: %f hard-dice: %f iou: %f took %fs (2d with distortion)\" %\n                (epoch, n_epoch, total_dice/n_batch, total_dice_hard/n_batch, total_iou/n_batch, time.time()-epoch_time))\n\r\n        ## save a prediction of training set\n        for i in range(batch_size):\n            if np.max(b_images[i]) > 0:\n                vis_imgs2(b_images[i], b_labels[i], out[i], \"samples/{}/train_{}.png\".format(task, epoch))\n                break\n            elif i == batch_size-1:\n                vis_imgs2(b_images[i], b_labels[i], out[i], 
\"samples/{}/train_{}.png\".format(task, epoch))\n\r\n        ###======================== EVALUATION ==========================###\n        total_dice, total_iou, total_dice_hard, n_batch = 0, 0, 0, 0\n        for batch in tl.iterate.minibatches(inputs=X_test, targets=y_test,\n                                        batch_size=batch_size, shuffle=True):\n            b_images, b_labels = batch\n            _dice, _iou, _diceh, out = sess.run([test_dice_loss,\n                    test_iou_loss, test_dice_hard, net_test.outputs],\n                    {t_image: b_images, t_seg: b_labels})\n            total_dice += _dice; total_iou += _iou; total_dice_hard += _diceh\n            n_batch += 1\n\r\n        print(\" **\"+\" \"*17+\"test 1-dice: %f hard-dice: %f iou: %f (2d no distortion)\" %\n                (total_dice/n_batch, total_dice_hard/n_batch, total_iou/n_batch))\n        print(\" task: {}\".format(task))\n        ## save a predition of test set\n        for i in range(batch_size):\n            if np.max(b_images[i]) > 0:\n                vis_imgs2(b_images[i], b_labels[i], out[i], \"samples/{}/test_{}.png\".format(task, epoch))\n                break\n            elif i == batch_size-1:\n                vis_imgs2(b_images[i], b_labels[i], out[i], \"samples/{}/test_{}.png\".format(task, epoch))\n\r\n        ###======================== SAVE MODEL ==========================###\n        tl.files.save_npz(net.all_params, name=save_dir+'/u_net_{}.npz'.format(task), sess=sess)\n\r\nif __name__ == \"__main__\":\r\n    import argparse\r\n    parser = argparse.ArgumentParser()\r\n\r\n    parser.add_argument('--task', type=str, default='all', help='all, necrotic, edema, enhance')\r\n\r\n    args = parser.parse_args()\r\n\r\n    main(args.task)\r\n"
  }
]