[
  {
    "path": "OE_eval.py",
    "content": "from collections import OrderedDict\nfrom options.train_options import TrainOptions\nfrom data import CreateDataLoader\nfrom models import create_model\nfrom PIL import Image\nimport time\nimport math\nfrom sklearn.metrics import balanced_accuracy_score, mean_squared_error\nfrom skimage.color import rgb2lab\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport os\nimport shutil\n\nimport logging\n\n\nopt = TrainOptions().parse()\n#train_data_loader = CreateDataLoader(opt)\n#train_dataset = train_data_loader.load_data()\n#train_dataset_size = len(train_data_loader)\n\nopt.phase = 'test/test_'\nopt.batch_size = 1\nopt.serial_batches = True\nopt.isTrain = False\ntest_data_loader = CreateDataLoader(opt)\ntest_dataset = test_data_loader.load_data()\ntest_dataset_size = len(test_data_loader)\n\nmodel = create_model(opt)\nmodel.setup(opt)\n\n\n# Set logger\nmsg = []\nlogger = logging.getLogger('%s' % opt.name)\nlogger.setLevel(logging.INFO)\nif not os.path.isdir(model.save_dir):\n  msg.append('%s not exist, make it' % model.save_dir)\n  os.mkdir(opt.dir)\nlog_file_path = os.path.join(model.save_dir, 'log.log')\nif os.path.isfile(log_file_path):\n  target_path = log_file_path + '.%s' % time.strftime(\"%Y%m%d%H%M%S\")\n  msg.append('Log file exists, backup to %s' % target_path)\n  shutil.move(log_file_path, target_path)\nfh = logging.FileHandler(log_file_path)\nfh.setLevel(logging.INFO)\nch = logging.StreamHandler()\nch.setLevel(logging.INFO)\n\nformatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\nfh.setFormatter(formatter)\nch.setFormatter(formatter)\nlogger.addHandler(fh)\nlogger.addHandler(ch)\n\n\ndef tensor2im(input_image, imtype=np.uint8):\n    \"\"\"\"Converts a Tensor array into a numpy image array.\n    Parameters:\n        input_image (tensor) --  the input image tensor array\n        imtype (type)        --  the desired type of the converted numpy array\n    \"\"\"\n    if not isinstance(input_image, 
np.ndarray):\n        if isinstance(input_image, torch.Tensor):  # get the data from a variable\n            image_tensor = input_image.data\n        else:\n            return input_image\n        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array\n        if image_numpy.shape[0] == 1:  # grayscale to RGB\n            image_numpy = np.tile(image_numpy, (3, 1, 1))\n            # image_numpy = image_numpy.convert('L')\n        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling\n    else:  # if it is a numpy array, do nothing\n        image_numpy = input_image\n    return np.clip(image_numpy, 0, 255).astype(imtype)\n\n\ndef calc_RMSE(real_img, fake_img):\n    # convert to LAB color space\n    real_lab = rgb2lab(real_img)\n    fake_lab = rgb2lab(fake_img)\n    return real_lab - fake_lab\n\nmodel.eval()\nmodel.load_networks('latest')\neval_shadow_rmse = 0\neval_nonshadow_rmse = 0\neval_rmse = 0\neval_loss = 0\nfor i, data in enumerate(test_dataset):\n    iter_start_time = time.time()\n    # total_steps += opt.batch_size\n    # epoch_iter += opt.batch_size\n    model.set_input(data)\n    model.forward()\n\n    # evaluation refers to matlab code\n    # diff = calc_RMSE(tensor2im(model.shadowfree_img), tensor2im(model.final))\n    # mask = model.shadow_mask.data[0].cpu().float().numpy()[..., None][0, ...]\n    #\n    # if mask.sum() < 2:\n    #     continue\n    # shadow_rmse = np.sqrt(1.0 * (np.power(diff, 2) * mask).sum(axis=(0, 1)) / mask.sum())\n    # nonshadow_rmse = np.sqrt(1.0 * (np.power(diff, 2) * (1 - mask)).sum(axis=(0, 1)) / (1 - mask).sum())\n    # whole_rmse = np.sqrt(np.power(diff, 2).mean(axis=(0, 1)))\n    #\n    # eval_shadow_rmse += shadow_rmse.sum()\n    # eval_nonshadow_rmse += nonshadow_rmse.sum()\n    # eval_rmse += whole_rmse.sum()\n\n    model.netR.zero_grad()\n\n    model.vis(0, i, opt.name, True)\n\nlogger.info('[Eval] [Epoch] %d | rmse : %.3f | 
shadow_rmse : %.3f | nonshadow_rmse : %.3f' % \n      (0, eval_rmse / len(test_dataset),\n      eval_shadow_rmse / len(test_dataset), eval_nonshadow_rmse / len(test_dataset)))\n\n"
  },
  {
    "path": "OE_eval.sh",
    "content": "#!/bin/bash\n\nMYGIT=/mnt/nvme/zcq/git\nREPO_PATH=/home/fulan/shadow_removal/exposure-fusion-shadow-removal\n# DATA_PATH=${MYGIT}/shadow_removal/data/SRD\n# datasetmode=srd\n\nDATA_PATH=./ISTD+\ndatasetmode=expo_param\n\n\nbatchs=8\nn=5\nks=3\nrks=3\nversion='fixed5-1-boundary-loss'\n\nlr_policy=lambda\nlr_decay_iters=50\noptimizer=adam\nshadow_loss=5.0\n\ntv_loss=0.0\ngrad_loss=0.1\npgrad_loss=0.0\n\ngpus=0\n\n\nlr=0.0001\nloadSize=256\nfineSize=256\nL1=10\nmodel=Refine\ncheckpoint=${REPO_PATH}/log\ndataroot=${DATA_PATH}\nNAME=\"M${model}_${datasetmode}_b${batchs}_lr${lr}_L1${L1}_n${n}_ks${ks}_v${version}_${optimizer}_${lr_policy}_${shadow_loss}_rks${rks}_TV${tv_loss}G${grad_loss}GP${pgrad_loss}\"\nOTHER=\"--save_epoch_freq 100 --niter 50 --niter_decay 350\"\n\n\ntrainmask=${dataroot}'/train_NOTUSE'\nCMD=\"python -u ./OE_eval.py --loadSize ${loadSize} \\\n    --randomSize\n    --name ${NAME} \\\n    --dataroot  ${dataroot}\\\n    --checkpoints_dir ${checkpoint} \\\n    --fineSize $fineSize --model $model \\\n    --batch_size $batchs \\\n    --randomSize --keep_ratio --phase train_  --gpu_ids ${gpus} --lr ${lr} \\\n    --lambda_L1 ${L1} --num_threads 16 \\\n    --dataset_mode $datasetmode\\\n    --mask_train $trainmask --optimizer ${optimizer} \\\n    --n ${n} --ks ${ks} --lr_policy ${lr_policy} --lr_decay_iters ${lr_decay_iters} \\\n    --shadow_loss ${shadow_loss} --tv_loss ${tv_loss} --grad_loss ${grad_loss} --pgrad_loss ${pgrad_loss} \\\n    $OTHER\n\"\necho $CMD\neval $CMD # >> ${checkpoint}/${NAME}.log 2>&1 &\n\n"
  },
  {
    "path": "OE_train.py",
    "content": "from collections import OrderedDict\nfrom options.train_options import TrainOptions\nfrom data import CreateDataLoader\nfrom models import create_model\nfrom PIL import Image\nimport time\nimport math\nfrom sklearn.metrics import balanced_accuracy_score, mean_squared_error\nfrom skimage.color import rgb2lab\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport os\nimport shutil\n\nimport logging\n\nopt = TrainOptions().parse()\n\nopt.phase = 'train/train_'\nopt.serial_batches = False\ntrain_data_loader = CreateDataLoader(opt)\ntrain_dataset = train_data_loader.load_data()\ntrain_dataset_size = len(train_data_loader)\n\nopt.phase = 'test/test_'\nopt.batch_size = 1\nopt.serial_batches = True\ntest_data_loader = CreateDataLoader(opt)\ntest_dataset = test_data_loader.load_data()\ntest_dataset_size = len(test_data_loader)\n\nmodel = create_model(opt)\nmodel.setup(opt)\nif opt.load_dir and opt.load_dir != 'None':\n    print('load fusion net from:', opt.load_dir)\n    model.load_networks('latest', opt.load_dir)\n\n# Set logger\nmsg = []\nlogger = logging.getLogger('%s' % opt.name)\nlogger.setLevel(logging.INFO)\nif not os.path.isdir(model.save_dir):\n    msg.append('%s not exist, make it' % model.save_dir)\n    os.mkdir(args.dir)\nlog_file_path = os.path.join(model.save_dir, 'log.log')\nif os.path.isfile(log_file_path):\n    target_path = log_file_path + '.%s' % time.strftime(\"%Y%m%d%H%M%S\")\n    msg.append('Log file exists, backup to %s' % target_path)\n    shutil.move(log_file_path, target_path)\nfh = logging.FileHandler(log_file_path)\nfh.setLevel(logging.INFO)\nch = logging.StreamHandler()\nch.setLevel(logging.INFO)\n\nformatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\nfh.setFormatter(formatter)\nch.setFormatter(formatter)\nlogger.addHandler(fh)\nlogger.addHandler(ch)\n\n\ndef tensor2im(input_image, imtype=np.uint8):\n    \"\"\"\"Converts a Tensor array into a numpy image array.\n    Parameters:\n        
input_image (tensor) --  the input image tensor array\n        imtype (type)        --  the desired type of the converted numpy array\n    \"\"\"\n    if not isinstance(input_image, np.ndarray):\n        if isinstance(input_image, torch.Tensor):  # get the data from a variable\n            image_tensor = input_image.data\n        else:\n            return input_image\n        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array\n        if image_numpy.shape[0] == 1:  # grayscale to RGB\n            image_numpy = np.tile(image_numpy, (3, 1, 1))\n            # image_numpy = image_numpy.convert('L')\n        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling\n    else:  # if it is a numpy array, do nothing\n        image_numpy = input_image\n    return np.clip(image_numpy, 0, 255).astype(imtype)\n\n\ndef calc_RMSE(real_img, fake_img):\n    # convert to LAB color space\n    real_lab = rgb2lab(real_img)\n    fake_lab = rgb2lab(fake_img)\n    return real_lab - fake_lab\n\n\nmse_criterion = nn.MSELoss()\ntotal_steps = 0\n\nfor epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):\n    epoch_start_time = time.time()\n    epoch_iter = 0\n    model.epoch = epoch\n\n    model.train()\n    for i, data in enumerate(train_dataset):\n        iter_start_time = time.time()\n        total_steps += 1\n        model.set_input(data)\n        # model.zero_grad()\n        model.optimize_parameters()\n\n        if total_steps % 10 == 0:\n            # Do log\n            train_loss = model.loss.detach().item()\n            train_mse = mse_criterion(model.final, model.shadowfree_img).detach().item()\n            logger.info('[Train] [Epoch] %d [Steps] %d | loss : %.3f' % (epoch, total_steps, train_loss))\n\n    if (epoch and epoch % 30 == 0) or (epoch == opt.niter + opt.niter_decay):\n        model.eval()\n        eval_shadow_rmse = 0\n        eval_nonshadow_rmse = 0\n        eval_rmse 
= 0\n        eval_loss = 0\n        for i, data in enumerate(test_dataset):\n            iter_start_time = time.time()\n            total_steps += opt.batch_size\n            epoch_iter += opt.batch_size\n            model.set_input(data)\n            model.forward()\n\n            eval_loss += model.loss.detach().item()\n            diff = calc_RMSE(tensor2im(model.shadowfree_img), tensor2im(model.final))\n            mask = model.shadow_mask.data[0].cpu().float().numpy()[..., None][0, ...]\n\n            if mask.sum() < 2:\n                continue\n            shadow_rmse = np.sqrt(1.0 * (np.power(diff, 2) * mask).sum(axis=(0, 1)) / mask.sum())\n            nonshadow_rmse = np.sqrt(1.0 * (np.power(diff, 2) * (1 - mask)).sum(axis=(0, 1)) / (1 - mask).sum())\n            whole_rmse = np.sqrt(np.power(diff, 2).mean(axis=(0, 1)))\n\n            # (256, 256, 3) (3,) (3,) (256, 256, 1) (256, 256, 3)\n            # print(diff.shape, whole_rmse.shape, shadow_rmse.shape, mask.shape, (diff * mask).shape)\n\n            eval_shadow_rmse += shadow_rmse.sum()\n            eval_nonshadow_rmse += nonshadow_rmse.sum()\n            eval_rmse += whole_rmse.sum()\n\n            model.zero_grad()\n\n            if i % 20 == 0:\n                model.vis(epoch, i)\n\n        logger.info('[Eval] [Epoch] %d | loss : %.3f | rmse : %.3f | shadow_rmse : %.3f | nonshadow_rmse : %.3f' %\n                    (epoch, eval_loss / len(test_dataset), eval_rmse / len(test_dataset),\n                     eval_shadow_rmse / len(test_dataset), eval_nonshadow_rmse / len(test_dataset)))\n\n    if epoch and epoch % 50 == 0 or (epoch == opt.niter + opt.niter_decay):\n        logger.info('saving the model at the end of epoch %d, iters %d' %\n                    (epoch, total_steps))\n        model.save_networks('latest')\n        model.save_networks(epoch)\n\n    spt_time = time.time() - epoch_start_time\n    lft_time = (opt.niter + opt.niter_decay - epoch) * spt_time\n    logger.info('End of epoch %d / 
%d | Time Taken: %d sec | eta %.2f' %\n                (epoch, opt.niter + opt.niter_decay, spt_time, lft_time / 3600.0))\n    model.update_learning_rate()\n"
  },
  {
    "path": "OE_train.sh",
    "content": "#!/bin/bash\n\n\nMYGIT=/mnt/nvme/zcq/git\nREPO_PATH=/home/fulan/shadow_removal\n# DATA_PATH=${MYGIT}/shadow_removal/data/SRD\n# datasetmode=srd\n\nDATA_PATH=/home/fulan/ShadowRemoval/data\ndatasetmode=expo_param\n\n\nbatchs=4\nn=5\nks=3\nrks=3\nversion='fixed5-1-loss'\n\nlr_policy=lambda\nlr_decay_iters=50\noptimizer=adam\nshadow_loss=10.0\n\ntv_loss=0\ngrad_loss=0.0\npgrad_loss=0.1\n\ngpus=0\n\n\nlr=0.0001\nloadSize=256\nfineSize=256\nL1=10\nmodel=Fusion\nload_dir=None\ncheckpoint=${REPO_PATH}/log\ndataroot=${DATA_PATH}\nNAME=\"M${model}_${datasetmode}_b${batchs}_lr${lr}_L1${L1}_n${n}_ks${ks}_v${version}_${optimizer}_${lr_policy}_${shadow_loss}_TV${tv_loss}G${grad_loss}PG${pgrad_loss}\"\nOTHER=\"--save_epoch_freq 100 --niter 50 --niter_decay 350\"\n\n\ntrainmask=${dataroot}'/train_NOTUSE'\nCMD=\"python -u ./OE_train.py --loadSize ${loadSize} \\\n    --randomSize\n    --name ${NAME} \\\n    --dataroot  ${dataroot}\\\n    --checkpoints_dir ${checkpoint} \\\n    --fineSize $fineSize --model $model \\\n    --batch_size $batchs \\\n    --randomSize --keep_ratio --phase train_  --gpu_ids ${gpus} --lr ${lr} \\\n    --lambda_L1 ${L1} --num_threads 16 \\\n    --dataset_mode $datasetmode\\\n    --mask_train $trainmask --optimizer ${optimizer} \\\n    --n ${n} --ks ${ks} --lr_policy ${lr_policy} --lr_decay_iters ${lr_decay_iters} \\\n    --shadow_loss ${shadow_loss} --rks ${rks} --tv_loss ${tv_loss} --grad_loss ${grad_loss} --pgrad_loss ${pgrad_loss} \\\n    --load_dir ${load_dir} \\\n    $OTHER\n\"\necho $CMD\neval $CMD # >> ${checkpoint}/${NAME}.log 2>&1 &\n\n"
  },
  {
    "path": "README.md",
    "content": "# Auto-exposure fusion for single-image shadow removal\nWe propose a new method for effective shadow removal by regarding it as an exposure fusion problem. Please refer to the paper for details: https://openaccess.thecvf.com/content/CVPR2021/papers/Fu_Auto-Exposure_Fusion_for_Single-Image_Shadow_Removal_CVPR_2021_paper.pdf.\n\n![Framework](./images/framework.png)\n\n## Dataset\n\n- ISTD [ISTD](https://github.com/DeepInsight-PCALab/ST-CGAN)\n- ISTD+ [ISTD+](https://github.com/cvlab-stonybrook/SID)\n- SRD\n\n1. Organize the data folder (e.g. ISTD/ISTD+) as follows, where train_A holds the shadow images, train_B the shadow masks, and train_C the shadow-free images:\n\n```shell\n--ISTD+\n   --train\n      --train_A\n          --1-1.png\n      --train_B\n          --1-1.png \n      --train_C_fixed_official \n          --1-1.png\n      --train_params_fixed  # generate later\n          --1-1.png.txt\n   --test\n      --test_A\n          --1-1.png\n      --test_B\n          --1-1.png\n      --test_C\n          --1-1.png\n      --mask_threshold   # generate later\n          --1-1.png\n ```\n \n 2. Run `./data_processing/compute_params.ipynb` to generate the exposure parameters.\n    The results are written to `./ISTD/train/train_params_fixed`.\n    Here, the folder names `train_C_fixed_official` and `train_params_fixed` are for the ISTD+ dataset and are consistent with `self.dir_C` and `self.dir_param` in `./data/expo_param_dataset.py`.\n 3. To generate the testing masks, run `./data_processing/test_mask_generation.py`. 
\n    The results are written to `./ISTD/test/mask_threshold`.\n\n\n## Pretrained models\n\nWe release our pretrained models for ISTD+ and SRD at [models](https://drive.google.com/drive/folders/1riTtYvHpffYu-nqSizqSF4fhbZ2txrp5?usp=sharing).\n\nThe pretrained model for ISTD is available at [models](https://drive.google.com/drive/folders/1qECA9EjUSLMtUpN5fFZMjltQPzjp2gL9?usp=sharing).\n\nTo load these models, set the parameter `model` in `OE_eval.sh` to `Refine` and set `ks=3, n=5, rks=3`.\n\n## Train\n\nModify the corresponding paths in `OE_train.sh` and run the following script:\n\n```shell\nsh OE_train.sh\n```\n1. For the parameters:\n```shell\n      DATA_PATH=./Datasets/ISTD or your datapath\n      n=5, ks=3 for FusionNet,\n      n=5, ks=3, rks=3 for RefineNet.\n      model=Fusion for FusionNet training,\n      model=Refine for RefineNet training.\n ```\n \n   The trained models are saved in `${REPO_PATH}/log/${Name}`, where `Name` is built from the parameter settings.\n\n## Test\n\nTo test a trained model, make sure that the hyperparameters in `OE_eval.sh` match those in `OE_train.sh`, then run the following script:\n\n```shell\nsh OE_eval.sh\n```\n1. The pretrained models should be placed in `${REPO_PATH}/log/${Name}`.\n\n## Evaluation\nThe results reported in the paper are computed with the `matlab` evaluation script used by other state-of-the-art methods; please see [evaluation](https://drive.google.com/file/d/1SAMqLy3dSONPgeC5ZQskPoeq60FEx9Vk/view?usp=sharing) for details. 
Our evaluation code prints the metrics computed in `python` and saves the shadow-removed result images, which are then used by the `matlab` script. (Note: in the current `OE_eval.py` the Python RMSE computation is commented out, so the RMSE values it logs are 0.)\n\n## Results\n\n- Comparison with SOTA methods; see the paper for details.\n\n![Visual comparison](./images/vis_compare.png)\n\n\n- Penumbra comparison between ours and SP+M Net\n\n![Penumbra comparison](./images/edge_comparsion.png)\n\n- Testing results\n\nThe testing results on the ISTD+, ISTD, and SRD datasets are available here: [results](https://drive.google.com/drive/folders/1ubLj5r_ZMzWew4h2bNX7pQL6D62mM-dl?usp=sharing)\n\n\n**More details are coming soon**\n\n## Bibtex\n\n```\n@inproceedings{fu2021auto,\n      title={Auto-exposure Fusion for Single-image Shadow Removal}, \n      author={Lan Fu and Changqing Zhou and Qing Guo and Felix Juefei-Xu and Hongkai Yu and Wei Feng and Yang Liu and Song Wang},\n      year={2021},\n      booktitle={accepted to CVPR}\n}\n```\n"
  },
  {
    "path": "data/__init__.py",
    "content": "import importlib\nimport torch.utils.data\nfrom data.base_data_loader import BaseDataLoader\nfrom data.base_dataset import BaseDataset\n\n\ndef find_dataset_using_name(dataset_name):\n    # Given the option --dataset_mode [datasetname],\n    # the file \"data/datasetname_dataset.py\"\n    # will be imported.\n    dataset_filename = \"data.\" + dataset_name + \"_dataset\"\n    datasetlib = importlib.import_module(dataset_filename)\n\n    # In the file, the class called DatasetNameDataset() will\n    # be instantiated. It has to be a subclass of BaseDataset,\n    # and it is case-insensitive.\n    dataset = None\n    target_dataset_name = dataset_name.replace('_', '') + 'dataset'\n    for name, cls in datasetlib.__dict__.items():\n        if name.lower() == target_dataset_name.lower() \\\n           and issubclass(cls, BaseDataset):\n            dataset = cls\n\n    if dataset is None:\n        print(\"In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase.\" % (dataset_filename, target_dataset_name))\n        exit(0)\n\n    return dataset\n\n\ndef get_option_setter(dataset_name):\n    dataset_class = find_dataset_using_name(dataset_name)\n    return dataset_class.modify_commandline_options\n\n\ndef create_dataset(opt):\n    dataset = find_dataset_using_name(opt.dataset_mode)\n    instance = dataset()\n    instance.initialize(opt)\n    print(\"dataset [%s] was created\" % (instance.name()))\n    return instance\n\n\ndef CreateDataLoader(opt):\n    data_loader = CustomDatasetDataLoader()\n    data_loader.initialize(opt)\n    return data_loader\n\n\n# Wrapper class of Dataset class that performs\n# multi-threaded data loading\nclass CustomDatasetDataLoader(BaseDataLoader):\n    def name(self):\n        return 'CustomDatasetDataLoader'\n\n    def initialize(self, opt):\n        BaseDataLoader.initialize(self, opt)\n        self.dataset = create_dataset(opt)\n        self.dataloader = 
torch.utils.data.DataLoader(\n            self.dataset,\n            batch_size=opt.batch_size,\n            shuffle=not opt.serial_batches,\n            num_workers=int(opt.num_threads))\n\n    def load_data(self):\n        return self\n\n    def __len__(self):\n        return min(len(self.dataset), self.opt.max_dataset_size)\n\n    def __iter__(self):\n        for i, data in enumerate(self.dataloader):\n            if i * self.opt.batch_size >= self.opt.max_dataset_size:\n                break\n            yield data\n"
  },
  {
    "path": "data/base_data_loader.py",
    "content": "class BaseDataLoader():\n    def __init__(self):\n        pass\n\n    def initialize(self, opt):\n        self.opt = opt\n        pass\n\n    def load_data():\n        return None\n"
  },
  {
    "path": "data/base_dataset.py",
    "content": "import torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\n\n\nclass BaseDataset(data.Dataset):\n    def __init__(self):\n        super(BaseDataset, self).__init__()\n\n    def name(self):\n        return 'BaseDataset'\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        return parser\n\n    def initialize(self, opt):\n        pass\n\n    def __len__(self):\n        return 0\n\n\ndef get_transform(opt):\n    transform_list = []\n    if opt.resize_or_crop == 'resize_and_crop':\n        osize = [opt.loadSize, opt.loadSize]\n        transform_list.append(transforms.Resize(osize, Image.BICUBIC))\n        transform_list.append(transforms.RandomCrop(opt.fineSize))\n    elif opt.resize_or_crop == 'crop':\n        transform_list.append(transforms.RandomCrop(opt.fineSize))\n    elif opt.resize_or_crop == 'scale_width':\n        transform_list.append(transforms.Lambda(\n            lambda img: __scale_width(img, opt.fineSize)))\n    elif opt.resize_or_crop == 'scale_width_and_crop':\n        transform_list.append(transforms.Lambda(\n            lambda img: __scale_width(img, opt.loadSize)))\n        transform_list.append(transforms.RandomCrop(opt.fineSize))\n    elif opt.resize_or_crop == 'none':\n        transform_list.append(transforms.Lambda(\n            lambda img: __adjust(img)))\n    else:\n        raise ValueError('--resize_or_crop %s is not a valid option.' 
% opt.resize_or_crop)\n\n    if opt.isTrain and not opt.no_flip:\n        transform_list.append(transforms.RandomHorizontalFlip())\n\n    transform_list += [transforms.ToTensor(),\n                       transforms.Normalize((0.5, 0.5, 0.5),\n                                            (0.5, 0.5, 0.5))]\n    return transforms.Compose(transform_list)\n\n\n# just modify the width and height to be multiple of 4\ndef __adjust(img):\n    ow, oh = img.size\n\n    # the size needs to be a multiple of this number,\n    # because going through generator network may change img size\n    # and eventually cause size mismatch error\n    mult = 4\n    if ow % mult == 0 and oh % mult == 0:\n        return img\n    w = (ow - 1) // mult\n    w = (w + 1) * mult\n    h = (oh - 1) // mult\n    h = (h + 1) * mult\n\n    if ow != w or oh != h:\n        __print_size_warning(ow, oh, w, h)\n\n    return img.resize((w, h), Image.BICUBIC)\n\n\ndef __scale_width(img, target_width):\n    ow, oh = img.size\n\n    # the size needs to be a multiple of this number,\n    # because going through generator network may change img size\n    # and eventually cause size mismatch error\n    mult = 4\n    assert target_width % mult == 0, \"the target width needs to be multiple of %d.\" % mult\n    if (ow == target_width and oh % mult == 0):\n        return img\n    w = target_width\n    target_height = int(target_width * oh / ow)\n    m = (target_height - 1) // mult\n    h = (m + 1) * mult\n\n    if target_height != h:\n        __print_size_warning(target_width, target_height, w, h)\n\n    return img.resize((w, h), Image.BICUBIC)\n\n\ndef __print_size_warning(ow, oh, w, h):\n    if not hasattr(__print_size_warning, 'has_printed'):\n        print(\"The image size needs to be a multiple of 4. \"\n              \"The loaded image size was (%d, %d), so it was adjusted to \"\n              \"(%d, %d). 
This adjustment will be done to all images \"\n              \"whose sizes are not multiples of 4\" % (ow, oh, w, h))\n        __print_size_warning.has_printed = True\n"
  },
  {
    "path": "data/expo_param_dataset.py",
    "content": "import os.path\nimport torchvision.transforms as transforms\nfrom data.base_dataset import BaseDataset, get_transform\nfrom data.image_folder import make_dataset\nfrom PIL import Image, ImageChops\nfrom PIL import ImageFilter\nimport torch\nfrom pdb import set_trace as st\nimport random\nimport numpy as np\nimport time\nimport cv2\n\n\nclass ExpoParamDataset(BaseDataset):\n    def initialize(self, opt):\n        self.opt = opt\n        self.root = opt.dataroot\n        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')\n        # if not opt.use_our_mask:\n        if 'test' in opt.phase:\n            self.dir_B = os.path.join(opt.dataroot, 'test', 'mask_threshold') # for test masks generated\n            self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C')\n        else:\n            self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')\n            self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C_fixed_official')\n\n        self.dir_param = os.path.join(opt.dataroot, opt.phase + 'params_fixed')  # opt.param_path\n        print(self.dir_A)\n        self.A_paths, self.imname = make_dataset(self.dir_A)\n        self.A_size = len(self.A_paths)\n        self.B_size = self.A_size\n\n        transform_list = [transforms.ToTensor(),\n                          transforms.Normalize(mean=opt.norm_mean,\n                                               std=opt.norm_std)]\n\n        self.transformA = transforms.Compose(transform_list)\n        self.transformB = transforms.Compose([transforms.ToTensor()])\n\n        if 'train' in opt.phase:\n            self.is_train = True\n        else:\n            self.is_train = False\n\n    def __getitem__(self, index):\n        colet = {}\n        A_path = self.A_paths[index % self.A_size]\n        imname = self.imname[index % self.A_size]\n        index_A = index % self.A_size\n\n        B_path = os.path.join(self.dir_B, imname.replace('.jpg', '.png'))\n        if not os.path.isfile(B_path):\n            
B_path = os.path.join(self.dir_B, imname)\n        A_img = Image.open(A_path).convert('RGB')\n\n        if self.is_train:\n            sparam = open(os.path.join(self.dir_param, imname + '.txt'))\n            line = sparam.read()\n            sparam.close()\n            shadow_param = np.asarray([float(i) for i in line.split(\" \") if i.strip()])\n            shadow_param = shadow_param[0:6]\n\n        ow = A_img.size[0]\n        oh = A_img.size[1]\n        w = np.float(A_img.size[0])\n        h = np.float(A_img.size[1])\n        if os.path.isfile(B_path):\n            B_img = Image.open(B_path)\n        else:\n            print('MASK NOT FOUND : %s' % (B_path))\n            B_img = Image.fromarray(np.zeros((int(w), int(h)), dtype=np.float), mode='L')\n\n        B_img_np = np.asarray(B_img)\n        kernel = np.ones((7, 7), np.uint8)\n        B_img_np_dilate = cv2.dilate(B_img_np, kernel, iterations=1)\n        B_img_np_erode = cv2.erode(B_img_np, kernel, iterations=1)\n        B_img_dilate = Image.fromarray(B_img_np_dilate)\n        B_img_erode = Image.fromarray(B_img_np_erode)\n        colet['B_dilate'] = B_img_dilate\n        colet['B_erode'] = B_img_erode\n\n        colet['C'] = Image.open(os.path.join(self.dir_C, imname)).convert('RGB')\n\n        loadSize = self.opt.loadSize\n        if self.is_train and self.opt.randomSize:\n            loadSize = np.random.randint(loadSize + 1, loadSize * 1.3, 1)[0]\n\n        if self.opt.keep_ratio:\n            if w > h:\n                ratio = np.float(loadSize) / np.float(h)\n                neww = np.int(w * ratio)\n                newh = loadSize\n            else:\n                ratio = np.float(loadSize) / np.float(w)\n                neww = loadSize\n                newh = np.int(h * ratio)\n        else:\n            neww = loadSize\n            newh = loadSize\n\n        colet['A'] = A_img\n        colet['B'] = B_img\n\n        if self.is_train:\n            t = [Image.FLIP_LEFT_RIGHT, Image.ROTATE_90]\n       
     for i in range(0, 4):\n                c = np.random.randint(0, 3, 1, dtype=np.int)[0]\n                if c == 2: continue\n                for i in ['A', 'B', 'C', 'B_dilate', 'B_erode']:\n                    if i in colet:\n                        colet[i] = colet[i].transpose(t[c])\n\n        if self.is_train:\n            degree = np.random.randint(-20, 20, 1)[0]\n            for i in ['A', 'B', 'C', 'B_dilate', 'B_erode']:\n                colet[i] = colet[i].rotate(degree)\n\n        for k, im in colet.items():\n            if self.is_train:\n                colet[k] = im.resize((neww, newh), Image.NEAREST)\n            else:\n                colet[k] = im.resize((self.opt.fineSize, self.opt.fineSize), Image.NEAREST)\n\n        w = colet['A'].size[0]\n        h = colet['A'].size[1]\n\n        for k, im in colet.items():\n            colet[k] = self.transformB(im)\n\n        for i in ['A', 'C', 'B', 'B_dilate', 'B_erode']:\n            if i in colet:\n                colet[i] = (colet[i] - 0.5) * 2\n\n        if self.is_train:  # and not self.opt.no_crop:\n            w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))\n            h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))\n            for k, im in colet.items():\n                colet[k] = im[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]\n\n        if self.is_train and (not self.opt.no_flip) and random.random() < 0.5:\n            idx = [i for i in range(colet['A'].size(2) - 1, -1, -1)]\n            idx = torch.LongTensor(idx)\n            for k, im in colet.items():\n                colet[k] = im.index_select(2, idx)\n\n        for k, im in colet.items():\n            colet[k] = im.type(torch.FloatTensor)\n\n        colet['imname'] = imname\n        colet['w'] = ow\n        colet['h'] = oh\n        colet['A_paths'] = A_path\n        colet['B_baths'] = B_path\n\n        if self.is_train:\n            # if the shadow area is too small, 
let's not change anything:\n            if torch.sum(colet['B'] > 0) < 30:\n                shadow_param = [0, 1, 0, 1, 0, 1]\n\n\n            colet['param'] = torch.FloatTensor(np.array(shadow_param))\n        return colet\n\n    def __len__(self):\n        return max(self.A_size, self.B_size)\n\n    def name(self):\n        return 'ExpoParamDataset'\n"
  },
  {
    "path": "data/image_folder.py",
    "content": "###############################################################################\n# Code from\n# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py\n# Modified the original code so that it also loads images from the current\n# directory as well as the subdirectories\n###############################################################################\n\nimport torch.utils.data as data\n\nfrom PIL import Image\nimport os\nimport os.path\n\nIMG_EXTENSIONS = [\n    '.jpg', '.JPG', '.jpeg', '.JPEG',\n    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\ndef make_dataset(dir):\n    images = []\n    imname = []\n    assert os.path.isdir(dir), '%s is not a valid directory' % dir\n\n    for root, _, fnames in sorted(os.walk(dir)):\n        for fname in fnames:\n            if is_image_file(fname):\n                path = os.path.join(root, fname)\n                images.append(path)\n                imname.append(fname)\n    return images,imname\n\ndef default_loader(path):\n    return Image.open(path).convert('RGB')\n\n\nclass ImageFolder(data.Dataset):\n\n    def __init__(self, root, transform=None, return_paths=False,\n                 loader=default_loader):\n        imgs = make_dataset(root)\n        if len(imgs) == 0:\n            raise(RuntimeError(\"Found 0 images in: \" + root + \"\\n\"\n                               \"Supported image extensions are: \" +\n                               \",\".join(IMG_EXTENSIONS)))\n\n        self.root = root\n        self.imgs = imgs\n        self.transform = transform\n        self.return_paths = return_paths\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path = self.imgs[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.return_paths:\n            return img, path\n 
       else:\n            return img\n\n    def __len__(self):\n        return len(self.imgs)\n"
  },
  {
    "path": "data/srd_dataset.py",
    "content": "import os.path\nimport torchvision.transforms as transforms\nfrom data.base_dataset import BaseDataset, get_transform\nfrom data.image_folder import make_dataset\nfrom PIL import Image, ImageChops\nfrom PIL import ImageFilter\nimport torch\nfrom pdb import set_trace as st\nimport random\nimport numpy as np\nimport time\nimport cv2\n\n\nclass SRDDataset(BaseDataset):\n    def initialize(self, opt):\n        self.opt = opt\n        self.root = opt.dataroot\n\n        if 'train' in opt.phase:\n            phase = 'train'\n        elif 'test' in opt.phase:\n            phase = 'test'\n        else:\n            raise False\n\n        self.dir_A = os.path.join(opt.dataroot, phase, 'shadow')\n        self.dir_B = os.path.join(opt.dataroot, phase, 'mask')\n        print(self.dir_B)\n        self.dir_C = os.path.join(opt.dataroot, phase, 'shadow_free')\n        self.dir_param = os.path.join(opt.dataroot, phase, 'mask_params')\n        print(self.dir_A)\n        self.A_paths, self.imname = make_dataset(self.dir_A)\n        self.A_size = len(self.A_paths)\n        self.B_size = self.A_size\n\n        transform_list = [transforms.ToTensor(),\n                          transforms.Normalize(mean=opt.norm_mean,\n                                               std=opt.norm_std)]\n\n        self.transformA = transforms.Compose(transform_list)\n        self.transformB = transforms.Compose([transforms.ToTensor()])\n\n        if 'train' in opt.phase:\n            self.is_train = True\n        else:\n            self.is_train = False\n\n    def __getitem__(self, index):\n        colet = {}\n        A_path = self.A_paths[index % self.A_size]\n        imname = self.imname[index % self.A_size]\n        index_A = index % self.A_size\n\n        B_path = os.path.join(self.dir_B, imname.replace('.jpg', '.png'))\n        if not os.path.isfile(B_path):\n            B_path = os.path.join(self.dir_B, imname)\n        A_img = Image.open(A_path).convert('RGB')\n\n        # if 
self.is_train:\n        sparam = open(os.path.join(self.dir_param, imname + '.txt'))\n        line = sparam.read()\n        sparam.close()\n        shadow_param = np.asarray([float(i) for i in line.split(\" \") if i.strip()])\n        shadow_param = shadow_param[0:6]\n\n        ow = A_img.size[0]\n        oh = A_img.size[1]\n        w = np.float(A_img.size[0])\n        h = np.float(A_img.size[1])\n        if os.path.isfile(B_path):\n            B_img = Image.open(B_path)\n        else:\n            print('MASK NOT FOUND : %s' % (B_path))\n            B_img = Image.fromarray(np.zeros((int(w), int(h)), dtype=np.float), mode='L')\n        if self.is_train:\n            colet['C'] = Image.open(os.path.join(self.dir_C, imname.replace('.jpg', '_no_shadow.jpg'))).convert('RGB')\n        else:\n            colet['C'] = Image.open(os.path.join(self.dir_C, imname.replace('.jpg', '_free.jpg'))).convert('RGB')\n\n        B_img_np = np.asarray(B_img)\n        kernel = np.ones((7, 7), np.uint8)\n        B_img_np_dilate = cv2.dilate(B_img_np, kernel, iterations=1)\n        B_img_np_erode = cv2.erode(B_img_np, kernel, iterations=1)\n        B_img_dilate = Image.fromarray(B_img_np_dilate)\n        B_img_erode = Image.fromarray(B_img_np_erode)\n        colet['B_dilate'] = B_img_dilate\n        colet['B_erode'] = B_img_erode\n\n        loadSize = self.opt.loadSize\n        if self.is_train and self.opt.randomSize:\n            loadSize = np.random.randint(loadSize + 1, loadSize * 1.3, 1)[0]\n\n        if self.opt.keep_ratio:\n            if w > h:\n                ratio = np.float(loadSize) / np.float(h)\n                neww = np.int(w * ratio)\n                newh = loadSize\n            else:\n                ratio = np.float(loadSize) / np.float(w)\n                neww = loadSize\n                newh = np.int(h * ratio)\n        else:\n            neww = loadSize\n            newh = loadSize\n\n        colet['A'] = A_img\n        colet['B'] = B_img\n\n        if 
self.is_train:\n            t = [Image.FLIP_LEFT_RIGHT, Image.ROTATE_90]\n            for i in range(0, 4):\n                c = np.random.randint(0, 3, 1, dtype=np.int)[0]\n                if c == 2: continue\n                for i in ['A', 'B', 'C', 'B_dilate', 'B_erode']:\n                    if i in colet:\n                        colet[i] = colet[i].transpose(t[c])\n\n        if self.is_train:\n            degree = np.random.randint(-20, 20, 1)[0]\n            for i in ['A', 'B', 'C', 'B_dilate', 'B_erode']:\n                colet[i] = colet[i].rotate(degree)\n\n        for k, im in colet.items():\n            if self.is_train:\n                colet[k] = im.resize((neww, newh), Image.NEAREST)\n            else:\n                colet[k] = im.resize((self.opt.fineSize, self.opt.fineSize), Image.NEAREST)\n\n        w = colet['A'].size[0]\n        h = colet['A'].size[1]\n\n        for k, im in colet.items():\n            colet[k] = self.transformB(im)\n\n        for i in ['A', 'C', 'B', 'B_dilate', 'B_erode']:\n            if i in colet:\n                colet[i] = (colet[i] - 0.5) * 2\n\n        if self.is_train:  # and not self.opt.no_crop:\n            w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))\n            h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))\n            for k, im in colet.items():\n                colet[k] = im[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]\n\n        if self.is_train and (not self.opt.no_flip) and random.random() < 0.5:\n            idx = [i for i in range(colet['A'].size(2) - 1, -1, -1)]\n            idx = torch.LongTensor(idx)\n            for k, im in colet.items():\n                colet[k] = im.index_select(2, idx)\n\n        for k, im in colet.items():\n            colet[k] = im.type(torch.FloatTensor)\n        colet['imname'] = imname\n        colet['w'] = ow\n        colet['h'] = oh\n        colet['A_paths'] = A_path\n        colet['B_baths'] = 
B_path\n\n        # if the shadow area is too small, let's not change anything:\n        if torch.sum(colet['B'] > 0) < 30:\n            shadow_param = [0, 1, 0, 1, 0, 1]\n\n        colet['param'] = torch.FloatTensor(np.array(shadow_param))\n        return colet\n\n    def __len__(self):\n        return max(self.A_size, self.B_size)\n\n    def name(self):\n        return 'SRDDataset'\n"
  },
  {
    "path": "data_processing/compute_params.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 48,\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"1330 13-14.png\\n\",\n      \"5.009201165757417 [ 2.4577046 -0.0107077] [1.96292375 0.03742387] [1.66641561 0.05142359]\\n\",\n      \"12.557166108231598 [1.         0.47089183] [1.01141423 0.40761759] [1.         0.35819274]\\n\",\n      \"7.532163787052489 [1.36470302 0.11617836] [1.35946438 0.07884616] [1.2162239  0.09544635]\\n\",\n      \"19.61040195664874 [1.         0.46429342] [1.         0.41372481] [1.         0.35411328]\\n\",\n      \"12.387102593381845 [1.         0.40474525] [1.         0.36618001] [1.         0.30457438]\\n\",\n      \"6.241485039361433 [1.24169934 0.24815966] [1.20049734 0.21402725] [1.20519596 0.12513773]\\n\",\n      \"5.785971619706131 [1.51092218 0.34851648] [1.21188995 0.34101258] [1.02058334 0.29614014]\\n\",\n      \"17.118862751395312 [1.66188257 0.26537164] [1.35520425 0.28154068] [1.77398978 0.11631677]\\n\",\n      \"4.218141292961971 [1.13025247 0.38224351] [1.10667088 0.33890229] [1.         0.31539165]\\n\",\n      \"11.894605985113497 [1.3901423  0.18765527] [1.35044863 0.16179696] [1.15484409 0.16976616]\\n\",\n      \"11.361911276969092 [1.36932553 0.20350377] [1.18084095 0.20083562] [1.14959475 0.12058902]\\n\",\n      \"9.86021931618332 [1.        0.3752986] [1.         0.32407264] [1.         0.26820877]\\n\",\n      \"10.947278660158846 [1.20420267 0.5       ] [1.02247322 0.5       ] [1.         0.43628735]\\n\",\n      \"6.657621145978359 [1.         0.20718372] [1.         0.17591685] [1.         
0.13735604]\\n\",\n      \"8.03704706604437 [1.32010748 0.31497458] [1.24215437 0.27602836] [1.0973914 0.2536129]\\n\",\n      \"5.482880754502581 [1.18206069 0.3588552 ] [1.1275162 0.2879413] [1.05704887 0.21233905]\\n\",\n      \"5.718165073349755 [1.29934299 0.16261593] [1.25885329 0.13936101] [1.1877504  0.13852079]\\n\",\n      \"7.016956595464088 [1.25482309 0.37105025] [1.08263569 0.36350617] [1.         0.33235748]\\n\",\n      \"6.897566305772376 [1.10145958 0.23318679] [1.09468612 0.19643508] [1.13940239 0.10195221]\\n\",\n      \"8.532390927615896 [1.00302409 0.30406921] [1.0087635  0.26458958] [1.04911879 0.1827265 ]\\n\",\n      \"11.39396945909047 [1.         0.40152748] [1.         0.36691934] [1.         0.32694646]\\n\",\n      \"11.583954785730356 [1.         0.46116362] [1.         0.42340029] [1.         0.36889922]\\n\",\n      \"13.963786390637335 [1.1974282  0.24911136] [1.16107098 0.21162959] [1.13311731 0.16491704]\\n\",\n      \"6.253649587591729 [1.00713065 0.25310759] [1.0224591  0.21871152] [1.02950319 0.17496982]\\n\",\n      \"3.6175409524276714 [1.         0.21817873] [1.         0.20064313] [1.         0.17062906]\\n\",\n      \"14.433281122388786 [1.         0.37965783] [1.         0.33575514] [1.         0.27812208]\\n\",\n      \"5.608802129687688 [1.7371973  0.36643941] [1.85441399 0.28241099] [1.35207175 0.28317621]\\n\",\n      \"9.114906090452688 [1.8711927  0.36635029] [1.61305913 0.34840897] [1.50288754 0.26483322]\\n\",\n      \"13.400176175652874 [1.         0.16318474] [1.         0.14702124] [1.09372213 0.04657046]\\n\",\n      \"9.759745110892707 [1.         0.32173525] [1.         0.30398725] [1.         0.27439679]\\n\",\n      \"2.8514657932585674 [1.25887751 0.14700267] [1.15925258 0.15962704] [1.12362727 0.14751425]\\n\",\n      \"9.679825820689372 [1.         0.44854566] [1.         0.40815696] [1.         0.34736059]\\n\",\n      \"11.073786320887729 [1.         0.45058488] [1.         0.42068032] [1.         
0.37156344]\\n\",\n      \"10.289516652242412 [1.         0.38406976] [1.         0.31952146] [1.00054163 0.24363679]\\n\",\n      \"3.023031724227071 [1.05675238 0.29276224] [1.04012443 0.25366816] [1.         0.21798194]\\n\",\n      \"6.168235449953648 [1.27013745 0.23122278] [1.35899604 0.12755769] [1.26931699 0.10761264]\\n\",\n      \"7.248479700487893 [1.         0.35115461] [1.        0.3240957] [1.         0.28089445]\\n\",\n      \"11.554876325373565 [1.         0.41879061] [1.         0.37834907] [1.         0.32560303]\\n\",\n      \"11.26561615969872 [1.         0.32843726] [1.         0.30285016] [1.         0.25962058]\\n\",\n      \"4.729147299093163 [1.14625236 0.45816921] [1.05362448 0.41870127] [1.         0.35341439]\\n\",\n      \"4.691457132264248 [1.60396452 0.28880753] [1.52081418 0.2326116 ] [1.12194682 0.22492234]\\n\",\n      \"5.084197838597331 [1.36933905 0.41946407] [1.15912625 0.40275493] [1.         0.36110154]\\n\",\n      \"5.17239733776305 [1.         0.41599939] [1.        0.3926829] [1.        0.3514369]\\n\",\n      \"10.250459697575046 [1.         0.33271481] [1.         0.30857566] [1.        0.2642584]\\n\",\n      \"6.023400035645858 [1.65635511 0.13666266] [1.59200143 0.11156766] [1.45946394 0.10254822]\\n\",\n      \"6.96555858019867 [1.         0.34628316] [1.         0.32011929] [1.         0.27775025]\\n\",\n      \"10.783664343047759 [1.         0.38274873] [1.         0.31831562] [1.01258584 0.23625687]\\n\",\n      \"9.495411617394018 [1.         0.44816713] [1.         0.41069855] [1.         0.35110202]\\n\",\n      \"14.09482520738358 [1.         0.45434453] [1.         0.42086814] [1.         0.36903418]\\n\",\n      \"10.68465403611464 [1.         0.32654092] [1.         0.30548285] [1.        
0.2727354]\\n\",\n      \"5.302505334273771 [1.23252662 0.17808913] [1.1569384  0.17391497] [1.13249722 0.15741759]\\n\",\n      \"9.652637898854056 [1.54465585 0.37706192] [1.40578235 0.35342029] [1.31892734 0.27949605]\\n\",\n      \"13.382923736675977 [1.         0.15977668] [1.         0.14423596] [1.08302745 0.03914978]\\n\",\n      \"5.628477691117605 [1.         0.30895195] [1.01787094 0.27018671] [1.06078142 0.20326901]\\n\",\n      \"10.542397198705956 [1.         0.33818339] [1.         0.28637546] [1.        0.2347259]\\n\",\n      \"12.394137628705234 [1.         0.40037421] [1.         0.34983589] [1.         0.28427398]\\n\",\n      \"6.105189830657419 [1.33760495 0.48293045] [1.15418303 0.45610392] [1.         0.41235753]\\n\",\n      \"11.044696194770463 [1.         0.41055094] [1.         0.37572367] [1.         0.33471414]\\n\",\n      \"5.078613467699998 [1.07033551 0.22390123] [1.0774847  0.17448601] [1.10485065 0.09423847]\\n\",\n      \"9.706465582443778 [1.25644739 0.20911097] [1.21251639 0.18062476] [1.21865354 0.09875737]\\n\",\n      \"6.362019292228281 [1.35115594 0.15168153] [1.28841872 0.13882535] [1.18778273 0.14618269]\\n\",\n      \"6.704423057174779 [1.24648772 0.36934891] [1.08134669 0.35938695] [1.         0.32914258]\\n\",\n      \"7.251582796373196 [1.29143978 0.32358843] [1.15770532 0.29475572] [1.06014823 0.25794168]\\n\",\n      \"5.6684857048323405 [1.28888514 0.34600941] [1.20387569 0.27758208] [1.13847092 0.19283189]\\n\",\n      \"6.796616557592407 [1.         0.21171769] [1.         0.17741299] [1.         0.13329753]\\n\",\n      \"12.05718679827463 [1.01485253 0.5       ] [1.14592859 0.41653497] [1.11974356 0.35647225]\\n\",\n      \"8.769923956648649 [1.         0.38290973] [1.         0.32841014] [1.         0.27251493]\\n\",\n      \"10.697150010065231 [1.27189658 0.2268371 ] [1.1559388 0.2066911] [1.11178238 0.13857463]\\n\",\n      \"4.379469658369307 [1.16802842 0.36942668] [1.09347926 0.34116608] [1.         
0.30731457]\\n\",\n      \"13.482325835777495 [1.54434777 0.1587063 ] [1.52810059 0.12175819] [1.36822658 0.12335494]\\n\",\n      \"5.894642644532999 [1.5203849  0.34331712] [1.24154046 0.3330313 ] [1.07096582 0.28564876]\\n\",\n      \"16.8750240387132 [1.62030722 0.2800683 ] [1.38151402 0.28671642] [1.62948089 0.14264266]\\n\",\n      \"8.124019704549188 [1.151602   0.28472924] [1.12190576 0.24737406] [1.10995035 0.16743737]\\n\",\n      \"18.279280570987343 [1.         0.37643914] [1.         0.34208776] [1.         0.28539649]\\n\",\n      \"19.66636441325457 [1.         0.46023776] [1.        0.4140824] [1.         0.35460682]\\n\",\n      \"6.357859224374215 [1.37712007 0.10885838] [1.38667765 0.0638982 ] [1.2322195 0.0838116]\\n\",\n      \"9.206116831806321 [1.78888139 0.04242404] [1.56327996 0.06241268] [1.3965774  0.06631569]\\n\",\n      \"13.076188356544833 [1.        0.4695472] [1.03514895 0.39739679] [1.         0.35989896]\\n\",\n      \"13.199088640671294 [1.         0.46120124] [1.03512727 0.38995599] [1.         0.34793226]\\n\",\n      \"14.463177949442647 [1.61275421 0.12496326] [1.48326322 0.1062259 ] [1.24045911 0.12648806]\\n\",\n      \"18.107028250869547 [1.         0.46686428] [1.         0.42096719] [1.05246637 0.35092622]\\n\",\n      \"3.5529687682538538 [1.05842061 0.26812992] [1.         0.27930972] [1.         0.18323113]\\n\",\n      \"14.566396520709839 [1.         0.39312816] [1.         0.35634256] [1.         0.29522218]\\n\",\n      \"12.528555453691613 [1.34846392 0.19787606] [1.32235555 0.1730581 ] [1.11582986 0.18224353]\\n\",\n      \"5.521435314014978 [1.15536791 0.38957656] [1.06943945 0.37053483] [1.         0.33839323]\\n\",\n      \"16.30208706976876 [1.71219158 0.26245386] [1.42384526 0.27685767] [1.77350493 0.11706204]\\n\",\n      \"6.493081806732496 [1.57368805 0.34476957] [1.29733911 0.32246715] [1.02607377 0.29064953]\\n\",\n      \"6.131783701655587 [1.02382791 0.23680186] [1.         0.21640777] [1.         
0.17907762]\\n\",\n      \"7.506828808097932 [1.19557558 0.2218131 ] [1.16212076 0.20064528] [1.11232589 0.15949729]\\n\",\n      \"11.161900229477755 [1.48012875 0.15105422] [1.25319232 0.16843267] [1.23539921 0.07926511]\\n\",\n      \"8.12040700176336 [1.42825465 0.5       ] [1.36049121 0.47434245] [1.17884733 0.43862526]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"9.132013795788252 [1.         0.34806853] [1.         0.31731916] [1.         0.27023051]\\n\",\n      \"9.792317679752305 [1.         0.36588639] [1.       0.314011] [1.         0.25793671]\\n\",\n      \"5.940965701437201 [1.         0.20803771] [1.         0.17323618] [1.         0.13119351]\\n\",\n      \"9.882733508141538 [1.30672327 0.19066968] [1.11155731 0.21128187] [1.1492221  0.10612051]\\n\",\n      \"10.405425120944802 [1.         0.43826299] [1.         0.39470588] [1.         0.34912023]\\n\",\n      \"7.060083238993572 [1.3451903  0.17801084] [1.27648049 0.16115865] [1.19106949 0.16273088]\\n\",\n      \"7.1223147442726695 [1.11513652 0.36908961] [1.55617714 0.18893727] [1.23582137 0.15631021]\\n\",\n      \"9.293580082372218 [1.         0.32047473] [1.        0.2833482] [1.         0.23739715]\\n\",\n      \"8.005524059669257 [1.28987369 0.21332551] [1.25316586 0.18238389] [1.20759584 0.11579795]\\n\",\n      \"7.76911158735872 [1.21623172 0.20625933] [1.18067534 0.18222335] [1.16642815 0.10963884]\\n\",\n      \"5.1918998443205195 [1.66332463 0.41683789] [1.39129656 0.3901788 ] [1.11576509 0.35980625]\\n\",\n      \"10.69094639243862 [1.         0.39350621] [1.        0.3582786] [1.         0.31643687]\\n\",\n      \"13.225567054079825 [1.         0.15839072] [1.         0.14323687] [1.10139522 0.04279168]\\n\",\n      \"12.342932240171312 [1.        0.3664039] [1.         0.32446021] [1.         
0.27077646]\\n\",\n      \"11.043933890756614 [1.59163681 0.38698595] [1.49561461 0.35451359] [1.41116016 0.28198947]\\n\",\n      \"9.094995620547467 [1.15021046 0.47727358] [1.08977463 0.42953044] [1.00129401 0.38367014]\\n\",\n      \"14.125721761128755 [1.19515688 0.27054272] [1.15095819 0.23091429] [1.09924844 0.19193861]\\n\",\n      \"3.6679252912694214 [1.         0.20445415] [1.         0.18896712] [1.         0.16061438]\\n\",\n      \"12.062789162656694 [1.         0.46486263] [1.         0.43506065] [1.02549722 0.38234206]\\n\",\n      \"9.68575062818469 [1.        0.3033293] [1.         0.26744613] [1.         0.22396062]\\n\",\n      \"9.809815320535373 [1.         0.38113812] [1.         0.31209012] [1.         0.22537621]\\n\",\n      \"10.430086093962018 [1.72870698 0.28330771] [1.83324658 0.20519742] [1.60239585 0.18656398]\\n\",\n      \"5.113332299433989 [1.46823856 0.30763214] [1.33919629 0.26388057] [1.08677162 0.23732179]\\n\",\n      \"16.71484743065607 [1.1872556 0.1499857] [1.         0.20026275] [1.32602403 0.03816908]\\n\",\n      \"3.9663087281319838 [1.5753108 0.1781449] [1.51843292 0.16357543] [1.31442388 0.16156391]\\n\",\n      \"18.307070983478926 [1.18915999 0.15383811] [1.         0.20327953] [1.38977648 0.03533171]\\n\",\n      \"7.253530355545665 [1.65216101 0.15262161] [1.60852392 0.1270196 ] [1.50196351 0.10690405]\\n\",\n      \"9.151588863604836 [1.89008521 0.27414988] [1.9818452  0.19755024] [1.69001509 0.19157092]\\n\",\n      \"5.079656461406473 [1.30158568 0.32988674] [1.19199073 0.29013097] [1.17694334 0.20035483]\\n\",\n      \"10.08549668679645 [1.         0.38304066] [1.10779293 0.28774702] [1.03929042 0.21792944]\\n\",\n      \"11.615106346615681 [1.        0.2876923] [1.        0.2568599] [1.         0.21963735]\\n\",\n      \"14.363588013460268 [1.        0.4534733] [1.         0.41606513] [1.         0.35647859]\\n\",\n      \"8.741799427958053 [1.02639393 0.5       ] [1.         0.45461541] [1.        
0.3876333]\\n\",\n      \"4.187578833825149 [1.         0.19033489] [1.         0.17401085] [1.         0.14676121]\\n\",\n      \"13.876034870909518 [1.53702276 0.15078739] [1.34135633 0.15243525] [1.31202896 0.1029404 ]\\n\",\n      \"10.569313277856985 [1.00796339 0.36697105] [1.00255843 0.32743355] [1.0288915 0.2625588]\\n\",\n      \"5.246471838025861 [1.68078264 0.42014281] [1.32655751 0.40559491] [1.01216441 0.3884433 ]\\n\",\n      \"8.273949420403152 [1.33623799 0.19335715] [1.29346782 0.16115181] [1.26786039 0.08569809]\\n\",\n      \"5.059069389593845 [1.18342216 0.35402176] [1.18418688 0.27000158] [1.07426841 0.20375367]\\n\",\n      \"7.76921206807748 [1.378397   0.28277659] [1.21494323 0.26918975] [1.10196424 0.23397128]\\n\",\n      \"6.43607447230013 [1.56492665 0.29695574] [1.28404774 0.29864386] [1.06882438 0.29462784]\\n\",\n      \"6.398037303945606 [1.45403557 0.14811863] [1.32537206 0.15033624] [1.20354689 0.15931622]\\n\",\n      \"9.55648203451093 [1.94467366 0.08649169] [1.76493891 0.1005126 ] [1.69881545 0.01980189]\\n\",\n      \"10.856624268671712 [1.        0.3673825] [1.         0.31482481] [1.         0.25988508]\\n\",\n      \"8.982680052796217 [1.         0.34127813] [1.         0.31518449] [1.         0.27009608]\\n\",\n      \"6.555164803237054 [1.2562734  0.20015526] [1.22347882 0.18298066] [1.14746297 0.15280522]\\n\",\n      \"11.42983616950065 [1.34457361 0.19784701] [1.19761175 0.18340349] [1.12131573 0.12921931]\\n\",\n      \"6.181647983118242 [1.02393465 0.23669126] [1.        0.2158752] [1.         
0.17661177]\\n\",\n      \"14.264335481464023 [1.36231952 0.3193632 ] [1.2304746  0.29892214] [1.36514555 0.1746504 ]\\n\",\n      \"6.572005530190152 [1.39173968 0.3702036 ] [1.17530396 0.34109917] [1.00265985 0.2970286 ]\\n\",\n      \"8.319346420255288 [1.52055688 0.1575683 ] [1.39921609 0.14575519] [1.16881791 0.15182689]\\n\",\n      \"5.611568650755693 [1.1495796  0.37269716] [1.14226881 0.32804759] [1.06802974 0.29075461]\\n\",\n      \"18.01291292936981 [1.         0.37746319] [1.         0.34310854] [1.         0.28435723]\\n\",\n      \"3.4836986777638694 [1.52675188 0.14018194] [1.         0.28277446] [1.         0.18527186]\\n\",\n      \"4.3317783224781605 [1.12365544 0.22748433] [1.17878037 0.16106564] [1.22965427 0.09299422]\\n\",\n      \"13.336767515409255 [1.         0.46458256] [1.12057777 0.37783388] [1.00717876 0.35011677]\\n\",\n      \"9.864917943383352 [1.6404268  0.06709127] [1.65048543 0.04040782] [1.34801337 0.0844676 ]\\n\",\n      \"16.846284924524845 [1.13472825 0.4631835 ] [1.21363804 0.3932583 ] [1.17821932 0.32699423]\\n\",\n      \"5.358754536184186 [1.23806703 0.17083908] [1.33126131 0.09281625] [1.31528204 0.05769854]\\n\",\n      \"10.353051767274382 [1.51270939 0.10647843] [1.654621   0.05133744] [1.41061676 0.07857836]\\n\",\n      \"4.648328470473069 [1.11023048 0.39100936] [1.04205856 0.36523198] [1.         0.32489038]\\n\",\n      \"13.316994078985502 [1.31008275 0.19355857] [1.42521755 0.13546461] [1.19475776 0.15705117]\\n\",\n      \"9.819794122576932 [1.         0.46014678] [1.         0.41603047] [1.         0.34684396]\\n\",\n      \"16.242143535088573 [1.         0.37921593] [1.         0.34364416] [1.       0.282832]\\n\",\n      \"9.194804363968492 [1.         0.13570096] [1.         0.10512691] [1.         0.08188649]\\n\",\n      \"11.379231412644675 [1.18159074 0.45204766] [1.33128828 0.36451897] [1.19303512 0.32488503]\\n\",\n      \"9.069394621391316 [1.         0.38514646] [1.08043136 0.30390721] [1.         
0.27154333]\\n\",\n      \"6.846509829714752 [1.32147852 0.20677532] [1.28256451 0.18644628] [1.20912227 0.1462376 ]\\n\",\n      \"10.966962761678616 [1.20763463 0.24372572] [1.14493124 0.2091043 ] [1.09776285 0.1472699 ]\\n\",\n      \"6.6067592297084135 [1.11052289 0.13083523] [1.04818177 0.1221258 ] [1.12051966 0.04052324]\\n\",\n      \"7.872035907686601 [1.28375202 0.20551887] [1.24402502 0.18556539] [1.1672168  0.18399265]\\n\",\n      \"4.329569989383152 [1.         0.19670338] [1.         0.15933851] [1.         0.05613744]\\n\",\n      \"7.17302436851573 [1.27467707 0.27052945] [1.39327588 0.20646313] [1.22131733 0.21167575]\\n\",\n      \"11.16483308047798 [1.47515678 0.26839173] [1.29870927 0.25360881] [1.24493635 0.19210514]\\n\",\n      \"5.319961589415595 [1.4293636  0.31697304] [1.34503128 0.24424028] [1.01072838 0.22989745]\\n\",\n      \"5.6766943497891305 [1.         0.19707673] [1.         0.16416763] [1.         0.11920103]\\n\",\n      \"7.718060246146565 [1.9227159  0.45543079] [1.73850337 0.40345459] [1.38004744 0.35508254]\\n\",\n      \"3.343854000457899 [1.         0.21764038] [1.        0.2100391] [1.         0.18451299]\\n\",\n      \"10.516698750314097 [1.         0.40731168] [1.         0.36490032] [1.         0.32373622]\\n\",\n      \"7.115375375306831 [1.15684331 0.25528336] [1.16037676 0.19798879] [1.12641917 0.12297683]\\n\",\n      \"5.948536479182282 [1.10257828 0.21344633] [1.09389845 0.17504921] [1.08962649 0.11284049]\\n\",\n      \"6.87330452421931 [1.26082107 0.23185223] [1.22121265 0.1967585 ] [1.23299833 0.10899838]\\n\",\n      \"9.963364348319965 [2.10171282 0.32908377] [1.89253652 0.30683977] [1.71861653 0.23969449]\\n\",\n      \"12.034387592660073 [1.         0.15319226] [1.         0.13411761] [1.06743429 0.0451563 ]\\n\",\n      \"5.109847361744034 [1.         0.23961853] [1.00441933 0.21933302] [1.         
0.18597711]\\n\",\n      \"9.18755024465854 [1.62144093 0.39797098] [1.64255571 0.33403221] [1.31080648 0.32105157]\\n\",\n      \"11.399652647647011 [1.04295124 0.47268458] [1.00648189 0.44236597] [1.         0.39338003]\\n\",\n      \"17.278222742674775 [1.19422497 0.29510424] [1.06277835 0.29093328] [1.19985579 0.14806508]\\n\",\n      \"5.493054634920303 [1.         0.12229506] [1.         0.10049863] [1.03689346 0.0570521 ]\\n\",\n      \"4.65671368314793 [1.4971448  0.30378396] [1.28672537 0.27268689] [1.14284681 0.21108244]\\n\",\n      \"10.60311527962159 [1.00474998 0.41151018] [1.123333   0.34908594] [1.07609795 0.30092226]\\n\",\n      \"9.807744768289167 [1.49269343 0.36348981] [1.49881221 0.30399422] [1.4106148  0.24874052]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"8.229391799374683 [1.11326697 0.28303568] [1.22700416 0.18165858] [1.18957228 0.16675759]\\n\",\n      \"5.817884163170584 [1.01720958 0.32880787] [1.02453528 0.30305116] [1.00257847 0.26936701]\\n\",\n      \"11.327098166250956 [1.03155823 0.25565992] [1.05183265 0.21572989] [1.02914384 0.1852006 ]\\n\",\n      \"7.149110333966241 [1.1455228  0.21281691] [1.13961522 0.20255734] [1.13505086 0.16736176]\\n\",\n      \"10.3813820806001 [1.         0.38640003] [1.         0.32192303] [1.02079451 0.23596475]\\n\",\n      \"10.112401221394446 [1.         0.37742396] [1.15628001 0.27022072] [1.06136077 0.19953287]\\n\",\n      \"8.233764537328963 [1.13845694 0.24339034] [1.13403178 0.2055601 ] [1.06929967 0.18204631]\\n\",\n      \"7.166512160354089 [1.13639967 0.20977336] [1.13292441 0.19817585] [1.13138368 0.16510997]\\n\",\n      \"9.43711205936275 [1.         0.33920725] [1.19067778 0.18427091] [1.17262074 0.15875156]\\n\",\n      \"5.7627867405678685 [1.         0.33192273] [1.         0.30700443] [1.         
0.26494425]\\n\",\n      \"5.799446886327617 [1.55970436 0.28742759] [1.28359163 0.27651333] [1.26502914 0.17683042]\\n\",\n      \"10.664803770222768 [1.         0.41277196] [1.         0.37605694] [1.         0.32483377]\\n\",\n      \"8.616222839391543 [1.74020995 0.29400656] [1.69706417 0.23602105] [1.48933511 0.20889967]\\n\",\n      \"6.99952774258929 [1.01321609 0.11497411] [1.00596574 0.10385598] [1.07796247 0.04209424]\\n\",\n      \"19.28470232275173 [1.32517208 0.25716803] [1.24178746 0.24519606] [1.40388588 0.10476253]\\n\",\n      \"10.585146058065424 [1.10167698 0.46575265] [1.04452282 0.43795931] [1.         0.39002365]\\n\",\n      \"3.9971776903870557 [1.         0.21280227] [1.00979397 0.1903038 ] [1.         0.16226249]\\n\",\n      \"8.600372691630932 [1.20916714 0.45718259] [1.22017397 0.399884  ] [1.06370449 0.36711354]\\n\",\n      \"10.50748004297346 [2.21778785 0.33338072] [1.95175535 0.3162305 ] [1.80184223 0.2410771 ]\\n\",\n      \"11.479631222666933 [1.         0.16122589] [1.         0.14020892] [1.07727474 0.0526391 ]\\n\",\n      \"6.009275495328019 [1.33026588 0.18172235] [1.27577272 0.15615949] [1.16958633 0.11121259]\\n\",\n      \"7.4231429910093 [1.25461666 0.23244019] [1.20777819 0.20047267] [1.22432189 0.10797243]\\n\",\n      \"10.002934277973846 [1.        0.3920379] [1.         0.35239683] [1.         0.31007609]\\n\",\n      \"3.9568556633323704 [1.         0.19923706] [1.         0.19080906] [1.         0.16440974]\\n\",\n      \"3.1996850236422163 [1.00566765 0.20323289] [1.02180513 0.16075047] [1.02813139 0.11744938]\\n\",\n      \"7.670147666302381 [1.32762098 0.30623213] [1.24472104 0.27057504] [1.17025815 0.22730785]\\n\",\n      \"6.042809587371163 [1.38519011 0.32888136] [1.26299585 0.2628564 ] [1.06580149 0.21291402]\\n\",\n      \"7.757869166937808 [1.07987583 0.1371057 ] [1.03446946 0.12497328] [1.10348111 0.04723894]\\n\",\n      \"4.43258421132082 [1.11721103 0.16522905] [1.12024512 0.12830612] [1.         
0.06038772]\\n\",\n      \"11.271611030082378 [1.3841545  0.17942604] [1.19712893 0.18344897] [1.19232273 0.09630849]\\n\",\n      \"3.698315232045804 [1.        0.3029356] [1.         0.29169131] [1.02321683 0.20416274]\\n\",\n      \"8.822285100363505 [1.         0.38251011] [1.14003079 0.28550306] [1.         0.26887895]\\n\",\n      \"6.753251142712777 [1.3186301  0.19447349] [1.277587   0.17462457] [1.19531644 0.13935587]\\n\",\n      \"11.206065921234933 [1.28159469 0.43717107] [1.55191477 0.32874611] [1.28062657 0.30535132]\\n\",\n      \"16.33983170885009 [1.         0.37769077] [1.         0.34479162] [1.         0.28664752]\\n\",\n      \"6.295083971343806 [1.87053536 0.32743493] [1.58078417 0.30719014] [1.21212806 0.28227725]\\n\",\n      \"9.656178527992829 [1.         0.42953287] [1.         0.39333673] [1.         0.32904069]\\n\",\n      \"16.617923937413938 [1.88175886 0.24733114] [1.57921308 0.24949027] [1.65726625 0.13864633]\\n\",\n      \"5.390447352010208 [1.13769409 0.39439566] [1.05165044 0.36885706] [1.         
0.32786999]\\n\",\n      \"13.459337896493485 [1.37798936 0.18296964] [1.56263841 0.12154884] [1.33853149 0.12959174]\\n\",\n      \"7.970838933707072 [1.79384372 0.05637351] [1.63839352 0.05992842] [1.45749686 0.05608501]\\n\",\n      \"6.1610743515074775 [1.27527954 0.14611874] [1.34536226 0.08104233] [1.272146   0.07017626]\\n\",\n      \"16.583950762246825 [1.20690134 0.45788159] [1.26357457 0.3906088 ] [1.24508291 0.32545267]\\n\",\n      \"6.2817083525225215 [1.43248899 0.08235235] [1.46319379 0.03151225] [1.34056285 0.04330153]\\n\",\n      \"15.46772954580591 [1.20037753 0.45752591] [1.28222324 0.38221786] [1.22166844 0.31956128]\\n\",\n      \"12.387730736901757 [1.37983368 0.15882999] [1.63323002 0.07241884] [1.4654103  0.08531616]\\n\",\n      \"14.137613449311157 [1.32896481 0.32018107] [1.21052352 0.29922719] [1.3130999  0.18296167]\\n\",\n      \"6.64786717329171 [1.24780339 0.3833198 ] [1.10150238 0.35102949] [1.04630689 0.27798395]\\n\",\n      \"4.24135722166321 [1.07243444 0.0987732 ] [1.08612152 0.06698599] [1.03942819 0.05278534]\\n\",\n      \"16.15573306961432 [1.         0.40620807] [1.         0.38199278] [1.         0.34536175]\\n\",\n      \"15.733509098532744 [1.         0.38269878] [1.         0.34514066] [1.         0.28323307]\\n\",\n      \"7.190322947344272 [1.33338819 0.20398601] [1.29896954 0.18313929] [1.22096887 0.14425465]\\n\",\n      \"8.848243226737866 [1.         0.36463984] [1.         0.30499512] [1.        0.2532191]\\n\",\n      \"9.192714346804564 [1.49756005 0.38641918] [1.47873111 0.33758922] [1.36019747 0.2924594 ]\\n\",\n      \"11.319870167575406 [1.1507284  0.25025712] [1.10234282 0.21313653] [1.07017249 0.14967353]\\n\",\n      \"3.122518940096929 [1.         0.25684418] [1.        0.2145787] [1.01112151 0.1713307 ]\\n\",\n      \"4.973297440166892 [1.75629566 0.26087785] [1.51903931 0.20265438] [1.30394731 0.12126478]\\n\",\n      \"17.577524487473408 [1.         0.32556971] [1.         
0.30692654] [1.32407937 0.14465335]\\n\",\n      \"9.477049842891258 [1.31975522 0.31669314] [1.2151192  0.28211103] [1.13099365 0.23748448]\\n\",\n      \"8.08814505673164 [1.07652963 0.34287907] [1.         0.32370187] [1.         0.28736337]\\n\",\n      \"9.65164923753427 [1.19063162 0.25259696] [1.18098305 0.22452833] [1.1382278  0.21833746]\\n\",\n      \"3.401865598938056 [1.03189485 0.19275967] [1.03970229 0.15411936] [1.05574905 0.10739656]\\n\",\n      \"14.176112013280372 [1.         0.37650505] [1.         0.34396706] [1.         0.30638204]\\n\",\n      \"6.688746200490611 [1.32125092 0.48171739] [1.19045539 0.45402005] [1.03386745 0.41444949]\\n\",\n      \"4.1057804673811615 [1.         0.10640612] [1.         0.08556885] [1.         0.05988287]\\n\",\n      \"6.67515108253583 [1.30299822 0.21644816] [1.2627708  0.17855494] [1.2811692  0.08289794]\\n\",\n      \"5.618221600263949 [1.13202922 0.20517403] [1.11029442 0.1728153 ] [1.11432531 0.09781313]\\n\",\n      \"9.683319846054019 [1.41664575 0.10072695] [1.3350207 0.0975795] [1.35620321 0.02201072]\\n\",\n      \"9.440038887014085 [1.18255191 0.45836029] [1.22659939 0.39898854] [1.0523846  0.37111137]\\n\",\n      \"3.101464387779503 [1.00310226 0.18617054] [1.01341032 0.16490255] [1.         0.14288429]\\n\",\n      \"11.430687928780806 [1.         0.14612842] [1.         0.13461617] [1.08796245 0.05549187]\\n\",\n      \"13.27307578626223 [2.03265487 0.34461573] [1.89730784 0.31029694] [1.85614307 0.22885712]\\n\",\n      \"13.522674934918868 [1.24809083 0.26623166] [1.15286954 0.24685906] [1.04216024 0.20770993]\\n\",\n      \"10.723073515894926 [1.         0.46249963] [1.         0.43080009] [1.         0.38489059]\\n\",\n      \"9.906265694175751 [1.62675252 0.08227807] [1.49673376 0.07923181] [1.35811367 0.06423116]\\n\",\n      \"8.396315197343291 [1.         0.43845084] [1.05420866 0.39183131] [1.05239784 0.3264069 ]\\n\",\n      \"5.9690038195265505 [1.         0.32685424] [1.         
0.30352812] [1.         0.26298348]\\n\",\n      \"7.840747356295273 [1.3913322  0.19861392] [1.34101817 0.1668658 ] [1.26671578 0.15635534]\\n\",\n      \"6.827495181836394 [2.01031885 0.28656483] [1.8092051  0.24054918] [1.50658006 0.22962397]\\n\",\n      \"9.895341011398294 [1.       0.417134] [1.         0.37778122] [1.        0.3224807]\\n\",\n      \"4.236715657828287 [1.72151572 0.27422812] [1.50583213 0.22894572] [1.13164725 0.2143216 ]\\n\",\n      \"9.78882843005207 [1.         0.38983699] [1.06821947 0.29867519] [1.02002297 0.21792161]\\n\",\n      \"11.323618689434669 [1.         0.28905041] [1.         0.25813764] [1.         0.22025564]\\n\",\n      \"2.827980676500224 [1.08528143 0.27993667] [1.0424405  0.24880152] [1.         0.21143698]\\n\",\n      \"7.388237834536167 [1.02595207 0.21509063] [1.00603128 0.18301106] [1.01643689 0.12862386]\\n\",\n      \"11.837966231279758 [1.         0.37373634] [1.         0.30779492] [1.        0.2218484]\\n\",\n      \"6.727837508030034 [1.79801516 0.32114686] [1.7475416  0.25240467] [1.46188216 0.25007253]\\n\",\n      \"9.794021663238883 [1.         0.42662672] [1.07878589 0.365327  ] [1.0410405 0.3164553]\\n\",\n      \"4.083904398142607 [1.79993235 0.25247301] [1.5791356  0.20800874] [1.12746493 0.20749548]\\n\",\n      \"5.721784128953531 [1.12182882 0.28560624] [1.12424996 0.26081497] [1.02604428 0.24848895]\\n\",\n      \"8.621892828153207 [1.5241035  0.17616859] [1.40255741 0.16787739] [1.27289634 0.16637972]\\n\",\n      \"5.4219216877500385 [1.46268234 0.37838065] [1.43721686 0.33107937] [1.30880109 0.27054054]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"12.92374294802015 [1.         0.43825739] [1.         0.40630202] [1.         0.35982225]\\n\",\n      \"12.12330184772465 [1.         0.31878988] [1.         0.29516628] [1.         
0.26243926]\\n\",\n      \"13.02981427411799 [1.22953443 0.27054121] [1.13130629 0.25406708] [1.02533223 0.21277484]\\n\",\n      \"11.11366702434372 [1.0001778  0.15644102] [1.        0.1407532] [1.07866234 0.05775561]\\n\",\n      \"13.643759331445702 [2.14823332 0.33184356] [2.03092151 0.29867382] [1.92904248 0.22367221]\\n\",\n      \"7.744886110161616 [1.20591221 0.45744155] [1.24288846 0.39719454] [1.08658366 0.36137132]\\n\",\n      \"10.809995204901764 [1.18865417 0.15685477] [1.08331153 0.15662978] [1.14261109 0.08033617]\\n\",\n      \"3.7238308588415934 [1.00328037 0.20303921] [1.00661734 0.18357931] [1.         0.15610408]\\n\",\n      \"7.936937679628445 [1.07265777 0.29389113] [1.06746621 0.25758217] [1.09819489 0.17303536]\\n\",\n      \"4.283265055990391 [1.        0.1119575] [1.         0.08966731] [1.         0.06296653]\\n\",\n      \"6.534887905244455 [1.41717386 0.15807677] [1.33269811 0.14386502] [1.15669023 0.11879725]\\n\",\n      \"5.944232542753346 [1.3929997  0.46064437] [1.2126758  0.43505099] [1.03033525 0.39986111]\\n\",\n      \"8.727786256388198 [1.         0.39480769] [1.         0.35426979] [1.         0.31290909]\\n\",\n      \"3.3846767085980463 [1.00647622 0.20466953] [1.01157291 0.16593815] [1.05980113 0.1060288 ]\\n\",\n      \"7.803723665864124 [1.         0.38890162] [1.         0.34317824] [1.         0.30201535]\\n\",\n      \"9.193485071927222 [1.22529701 0.23804733] [1.19621973 0.22068826] [1.17930709 0.20824439]\\n\",\n      \"16.949125286054716 [1.         0.34446065] [1.         0.31762179] [1.33182696 0.15283466]\\n\",\n      \"3.8748364223382645 [1.0611979  0.22382414] [1.        0.2178958] [1.12213093 0.13030333]\\n\",\n      \"11.04633248837437 [1.31633276 0.20016345] [1.17997035 0.18777593] [1.1162778  0.13053015]\\n\",\n      \"9.276156641897078 [1.07789592 0.47502514] [1.11088407 0.41493342] [1.08391157 0.35643613]\\n\",\n      \"10.158945558624081 [1.         0.36348312] [1.         0.31503743] [1.        
0.2628768]\\n\",\n      \"14.34898535212909 [1.         0.38492621] [1.         0.34894339] [1.         0.29041724]\\n\",\n      \"4.396947412245824 [1.04215175 0.11467678] [1.07034018 0.07286891] [1.02961519 0.06183437]\\n\",\n      \"11.136538148338126 [1.90098929 0.04814792] [1.5652092  0.06581793] [1.36260503 0.06558594]\\n\",\n      \"14.846407868283405 [1.         0.40598203] [1.         0.38136962] [1.         0.34648215]\\n\",\n      \"3.489638923048433 [1.42120371 0.33105368] [1.53614666 0.25299199] [1.33302736 0.21516879]\\n\",\n      \"15.151601952965605 [1.75116773 0.26517765] [1.52284581 0.26063396] [1.71683974 0.12902499]\\n\",\n      \"7.474191432710572 [1.46522048 0.37145897] [1.16752892 0.36515838] [1.         0.32322039]\\n\",\n      \"13.635094007444565 [1.         0.47359954] [1.0029611  0.39964454] [1.         0.35599858]\\n\",\n      \"11.971244137120992 [1.72424873 0.09580737] [1.62864594 0.07245187] [1.37353869 0.09478861]\\n\",\n      \"14.744004694629867 [1.47259561 0.42305383] [1.46576631 0.35352895] [1.35692132 0.29742234]\\n\",\n      \"7.7499555168184875 [1.42205944 0.13015136] [1.45947697 0.0760536 ] [1.39622789 0.05506497]\\n\",\n      \"10.318015080425308 [1.22651317 0.19469698] [1.2040185  0.18711634] [1.16993615 0.16435758]\\n\",\n      \"7.508101468781912 [1.         0.27320416] [1.         0.24114559] [1.         0.18801963]\\n\",\n      \"7.353069512400742 [1.07189384 0.40878328] [1.06308897 0.37137663] [1.05266062 0.32375307]\\n\",\n      \"11.307551184664831 [1.         0.28636899] [1.         0.26937179] [1.         0.24237139]\\n\",\n      \"10.636613739433855 [1.21212062 0.11645089] [1.19701997 0.10758946] [1.20047315 0.08228245]\\n\",\n      \"12.87306339200131 [1.22334068 0.26602585] [1.12109215 0.25073462] [1.19059815 0.1187465 ]\\n\",\n      \"16.46473971972104 [1.42511277 0.37471813] [1.3797633  0.32869782] [1.33329276 0.2587746 ]\\n\",\n      \"7.7680470356588955 [1.         0.35555614] [1.         0.33216718] [1.    
     0.28608048]\\n\",\n      \"4.438435829783888 [1.1681177  0.11714117] [1.18395561 0.08094577] [1.14427211 0.05225899]\\n\",\n      \"5.014186239267677 [1.         0.27324214] [1.         0.25595753] [1.         0.21792883]\\n\",\n      \"5.93390200897512 [1.91139199 0.01259635] [1.48077759 0.10922466] [1.         0.29773391]\\n\",\n      \"9.32348760898458 [1.89333089 0.39569778] [1.55067934 0.3967098 ] [1.33230293 0.37010351]\\n\",\n      \"12.903284216394288 [1.         0.38405575] [1.         0.35803423] [1.         0.31370152]\\n\",\n      \"11.737460110071973 [1.         0.32668573] [1.         0.30409448] [1.         0.26355711]\\n\",\n      \"12.764559174176856 [1.28387258 0.35579261] [1.14542861 0.32634888] [1.00283855 0.28218503]\\n\",\n      \"7.358532248887396 [1.07442551 0.26832179] [1.10202569 0.20488099] [1.05784295 0.1513299 ]\\n\",\n      \"3.9191805140817237 [1.09084507 0.26922418] [1.02677945 0.25541651] [1.         0.22129397]\\n\",\n      \"8.851650846913161 [1.        0.3687964] [1.         0.34034031] [1.         0.29227185]\\n\",\n      \"6.189058073514696 [1.23263448 0.31059963] [1.23537597 0.25595354] [1.10422191 0.24668319]\\n\",\n      \"8.846340005888473 [1.         0.28203295] [1.         0.24657353] [1.         0.19505829]\\n\",\n      \"5.855337803107138 [1.         0.31724028] [1.         0.29185951] [1.         0.24919558]\\n\",\n      \"14.37666345589674 [1.         0.35874386] [1.        0.3136187] [1.         0.24626372]\\n\",\n      \"11.560589828261861 [1.14596985 0.5       ] [1.         0.48494053] [1.         0.41586715]\\n\",\n      \"9.17277299686883 [1.         0.37383266] [1.         0.33817492] [1.       0.288125]\\n\",\n      \"13.109723379400041 [1.         0.34576873] [1.         0.30178404] [1.        
0.2354756]\\n\",\n      \"13.457486269408257 [1.14818376 0.47345392] [1.21100018 0.42181602] [1.19799818 0.36756067]\\n\",\n      \"7.780002402065833 [1.47235295 0.21392684] [1.34075937 0.18970313] [1.16090543 0.15055719]\\n\",\n      \"4.515234402687749 [1.30800007 0.05901288] [1.27854246 0.06173632] [1.26136428 0.04963119]\\n\",\n      \"7.7002553235279 [1.3709947  0.40154594] [1.16884158 0.39594325] [1.         0.38857235]\\n\",\n      \"8.519950021502659 [2.56890745 0.38903405] [1.99093146 0.40651462] [1.48469118 0.40834175]\\n\",\n      \"11.096057531064517 [1.         0.32453121] [1.         0.29025584] [1.         0.24166803]\\n\",\n      \"11.793027738096868 [1.         0.40430373] [1.         0.35858496] [1.         0.28624796]\\n\",\n      \"10.858903574152004 [1.19549563 0.27543131] [1.1129744  0.25376468] [1.05648035 0.202586  ]\\n\",\n      \"6.709065599768365 [1.3076391 0.1361815] [1.30242625 0.08963543] [1.24594118 0.07549374]\\n\",\n      \"7.9727635055249015 [1.21539292 0.18875852] [1.19715729 0.1650433 ] [1.16333386 0.10402095]\\n\",\n      \"16.899611590422623 [1.06484678 0.3397656 ] [1.         0.31900987] [1.12891387 0.17694057]\\n\",\n      \"6.636931715410257 [1.47576206 0.30595888] [1.44386168 0.26626175] [1.21528856 0.24989117]\\n\",\n      \"12.461808410693672 [1.         0.40222813] [1.         0.36594668] [1.         0.30994769]\\n\",\n      \"7.228626965483202 [1.08957594 0.30282639] [1.10278698 0.27700384] [1.07887094 0.24984452]\\n\",\n      \"4.717126212878538 [1.         0.27271711] [1.45203883e+00 1.36755657e-03] [1.2656421  0.01236516]\\n\",\n      \"8.337957903790556 [1.44915374 0.30760317] [1.30228489 0.29604403] [1.51546082 0.15115157]\\n\",\n      \"4.310434864075449 [1.04860806 0.23784963] [ 1.42965787 -0.00999866] [1.24416184 0.00394949]\\n\",\n      \"8.954482582718153 [1.87589561 0.23543362] [1.46069088 0.25948162] [1.82186006 0.09131086]\\n\",\n      \"10.76756632971838 [1.        0.4056859] [1.         0.37036067] [1.    
     0.31855702]\\n\",\n      \"10.164545587235924 [1.5658485  0.37006108] [1.28072399 0.38086746] [1.02952132 0.37038356]\\n\",\n      \"6.516128297771173 [1.04107228 0.31526783] [1.07097401 0.28386317] [1.05898164 0.25225454]\\n\",\n      \"19.348904782894852 [1.07353974 0.34588991] [1.         0.32892898] [1.12542758 0.19515951]\\n\",\n      \"6.224941790567825 [1.3365501  0.17358007] [1.30171357 0.14523596] [1.22255449 0.09457928]\\n\",\n      \"5.668600164185342 [1.31913522 0.1013403 ] [1.28318693 0.06483723] [1.1684233  0.05374388]\\n\",\n      \"5.879670066249189 [1.28783993 0.17444356] [1.24822519 0.14295736] [1.26440323 0.05057444]\\n\",\n      \"7.0295185686981245 [1.32749237 0.14445933] [1.34649887 0.08903797] [1.28815073 0.07531211]\\n\",\n      \"13.432121950887783 [1.09368324 0.26136207] [1.06315197 0.23826839] [1.02415159 0.194419  ]\\n\",\n      \"7.431158701186877 [1.52387026 0.45575523] [1.26768842 0.46138969] [1.00912014 0.45951464]\\n\",\n      \"9.133075847433641 [1.         0.32763083] [1.        0.2952476] [1.       0.247423]\\n\",\n      \"8.086363101595888 [1.         0.42089985] [1.         0.37160776] [1.         0.29277675]\\n\",\n      \"3.5368746256625876 [1.3152582  0.05525126] [1.28577703 0.05892926] [1.26281295 0.04770792]\\n\",\n      \"8.912787507700697 [1.34011579 0.41647153] [1.32052103 0.38440041] [1.11624175 0.38066937]\\n\",\n      \"7.251552657740937 [1.44296037 0.21621846] [1.33010513 0.18317351] [1.15585674 0.14626113]\\n\",\n      \"11.852161540030494 [1.         0.36114575] [1.         0.31125812] [1.         0.23774902]\\n\",\n      \"13.012564640990725 [1.        0.4774848] [1.08504104 0.42621422] [1.08582728 0.38513799]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"10.399040863387082 [1.         0.34256646] [1.         0.31658962] [1.        0.2745839]\\n\",\n      \"15.365551767647597 [1.        
0.3479788] [1.06557821 0.29742744] [1.05550735 0.2351434 ]\\n\",\n      \"5.530280188449658 [1.32775307 0.23470935] [1.26303032 0.22132559] [1.14345919 0.20878732]\\n\",\n      \"8.872598365341643 [1.         0.29197827] [1.         0.25541775] [1.         0.20189065]\\n\",\n      \"4.363754424405304 [1.08907136 0.26384361] [1.04647512 0.24163538] [1.         0.21388896]\\n\",\n      \"13.529900174983718 [1.         0.41353479] [1.        0.4040192] [1.         0.38084503]\\n\",\n      \"6.579937201837601 [1.09192602 0.35805918] [1.11899994 0.30419861] [1.         0.29325943]\\n\",\n      \"11.521328965457522 [1.         0.32320311] [1.         0.30077645] [1.         0.25973495]\\n\",\n      \"9.514036722453081 [1.         0.14215642] [1.         0.12358407] [1.01925728 0.08379428]\\n\",\n      \"13.70804099395271 [1.         0.39063537] [1.         0.36265178] [1.         0.31751861]\\n\",\n      \"11.592487347281375 [1.77474634 0.41289645] [1.42922063 0.41914452] [1.25133301 0.3907739 ]\\n\",\n      \"4.588247546140393 [1.        0.1979672] [1.         0.18163717] [1.         0.15291255]\\n\",\n      \"7.827889851702488 [1.         0.36113917] [1.         0.33438351] [1.         0.29454496]\\n\",\n      \"11.02283657104899 [1.         0.44227969] [2.70764757 0.14443202] [2.08253508 0.10108426]\\n\",\n      \"4.570555226809893 [1.12471671 0.13227695] [1.15569754 0.0912686 ] [1.12649879 0.05929105]\\n\",\n      \"9.720891340088501 [1.11818475 0.28343482] [1.09872754 0.2782892 ] [1.03840465 0.27418947]\\n\",\n      \"6.2448992341603615 [1.18327301 0.11860671] [1.16407577 0.11256122] [1.14903747 0.09490972]\\n\",\n      \"19.334696900916832 [1.         0.46099643] [1.        0.4122927] [1.14149199 0.31152665]\\n\",\n      \"13.052606369671384 [1.20489344 0.2677766 ] [1.14234011 0.24176071] [1.2235349  0.10701297]\\n\",\n      \"7.488423484356947 [1.24143361 0.37978145] [1.13644446 0.36971169] [1.10550541 0.33195308]\\n\",\n      \"7.878020226024404 [1.         
0.27620113] [1.         0.24541771] [1.         0.19496444]\\n\",\n      \"9.843909234368232 [1.27755267 0.19437357] [1.25222751 0.1835578 ] [1.20303539 0.16027326]\\n\",\n      \"10.41734972168427 [1.2781173  0.16988161] [1.25269049 0.16499974] [1.21396005 0.13883092]\\n\",\n      \"8.57782745968769 [1.         0.18388329] [1.12097025 0.13244987] [1.06700681 0.13676794]\\n\",\n      \"6.752323987887464 [1.         0.41286498] [1.        0.3741202] [1.         0.31717428]\\n\",\n      \"6.894033251838395 [1.14965681 0.23590677] [1.13442137 0.23065988] [1.07788449 0.22178575]\\n\",\n      \"4.533588806434659 [1.         0.42401566] [1.66369299 0.13481584] [1.38499166 0.11752162]\\n\",\n      \"17.447802191463197 [1.33086438 0.41907321] [1.35122617 0.36268071] [1.43848487 0.26100588]\\n\",\n      \"8.855210245864539 [1.44391281 0.38613494] [1.28211709 0.35994396] [1.04366756 0.32924053]\\n\",\n      \"8.593938087065613 [1.66144589 0.03043435] [1.17368481 0.18860503] [1.        0.3112879]\\n\",\n      \"4.516192672820185 [1.         0.17315768] [1.         0.16027499] [1.         0.13601338]\\n\",\n      \"9.930369039926079 [1.         0.45657104] [1.         0.41818039] [1.        0.3525668]\\n\",\n      \"11.750623067675095 [1.         0.38959995] [1.         0.36583342] [1.       0.324724]\\n\",\n      \"8.801304702139626 [1.86076771 0.39463772] [1.4964728  0.39759935] [1.27639272 0.36993752]\\n\",\n      \"10.42730869983231 [1.         0.13676754] [1.        0.1192226] [1.04220164 0.0416786 ]\\n\",\n      \"7.955986863592911 [1.        0.4017892] [1.         0.37192542] [1.         0.32327928]\\n\",\n      \"4.797502828089043 [1.50201568 0.25921432] [1.54890249 0.19463649] [1.37743284 0.18441868]\\n\",\n      \"9.819222389038643 [1.         0.34934257] [1.         0.32134333] [1.         
0.27584404]\\n\",\n      \"8.26535325760776 [1.06513529 0.09520454] [1.03612411 0.08847213] [1.01449466 0.07178682]\\n\",\n      \"4.141517173120085 [1.39447253 0.18245722] [1.26644478 0.18135569] [1.17966932 0.16226805]\\n\",\n      \"6.781159710831368 [1.02872215 0.26257408] [1.05882205 0.19904078] [1.03839647 0.14275096]\\n\",\n      \"11.396194060489748 [1.02887201 0.42703845] [1.         0.38434703] [1.        0.2990352]\\n\",\n      \"6.514048582404852 [1.         0.31228228] [1.         0.28870107] [1.         0.24761954]\\n\",\n      \"12.263896963901542 [1.         0.32230866] [1.         0.29873797] [1.        0.2590367]\\n\",\n      \"5.660882528116615 [1.         0.30433113] [1.         0.28598432] [1.         0.24896655]\\n\",\n      \"12.641431911900085 [1.         0.49750816] [1.         0.45185861] [1.         0.37262997]\\n\",\n      \"6.785688916406413 [1.         0.30198109] [1.         0.28480017] [1.         0.25618031]\\n\",\n      \"14.816778587103881 [1.1193813  0.46257033] [1.14025848 0.4076001 ] [1.15249545 0.36044309]\\n\",\n      \"14.0008259359064 [1.         0.33486403] [1.         0.29191746] [1.         0.22265205]\\n\",\n      \"5.486457612990889 [1.28396691 0.38449935] [1.15504614 0.34593028] [1.         0.29126258]\\n\",\n      \"10.525596361418229 [1.         0.35591955] [1.         0.32451637] [1.        
0.2782514]\\n\",\n      \"5.797371127000429 [2.64080811 0.43203536] [2.08242611 0.44599074] [1.43901104 0.44696342]\\n\",\n      \"8.29309385803627 [1.43765078 0.38701073] [1.30139209 0.37205705] [1.09309882 0.37037415]\\n\",\n      \"5.4541989864053075 [1.2645468  0.06761498] [1.24424396 0.0679545 ] [1.21825074 0.0590681 ]\\n\",\n      \"12.111974237275414 [1.12780225 0.27006465] [1.08212607 0.24573226] [1.03817795 0.1977042 ]\\n\",\n      \"10.035389830912521 [1.33452756 0.11448563] [1.27556688 0.09970234] [1.25475358 0.05040068]\\n\",\n      \"7.122328355171524 [1.00631411 0.20625468] [1.00821955 0.18043911] [1.00000001 0.13051789]\\n\",\n      \"6.74398160124192 [1.33823143 0.14356404] [1.34987935 0.09204396] [1.26482916 0.08544395]\\n\",\n      \"4.541000483994441 [1.51780067 0.14736122] [1.36745694 0.12999866] [1.21296419 0.1007455 ]\\n\",\n      \"16.44394747942648 [1.161461   0.30433507] [1.02411477 0.30137427] [1.15681405 0.15734204]\\n\",\n      \"12.30836024100948 [1.        0.3582024] [1.         0.33574834] [1.         0.23375577]\\n\",\n      \"4.921298954195917 [1.         0.26835638] [1.43996426e+00 2.10709231e-04] [1.25909389 0.01114293]\\n\",\n      \"6.800284197181853 [1.         0.33739754] [1.         0.31960835] [1.         0.28985652]\\n\",\n      \"6.047774476938553 [1.35031977 0.30326913] [1.33688011 0.25626945] [1.17761302 0.22920831]\\n\",\n      \"12.285616934681238 [1.         0.41184596] [1.         0.37458676] [1.         0.32193021]\\n\",\n      \"7.4385668492911154 [1.         0.34272029] [1.         0.31921037] [1.         0.28272231]\\n\",\n      \"13.748919755251686 [1.         0.38760815] [1.        0.3527708] [1.         0.29410875]\\n\",\n      \"7.573395290740324 [1.54059085 0.28434904] [1.45211232 0.25511036] [1.23431808 0.23588014]\\n\",\n      \"12.827942385184194 [1.         0.34133679] [1.         0.32517513] [1.08081162 0.2048545 ]\\n\",\n      \"5.315138803930238 [1.         
0.26630479] [1.46135346 0.00174964] [1.26930529 0.01257144]\\n\",\n      \"18.445641513586367 [1.         0.32009183] [1.         0.28623857] [1.10184002 0.15859606]\\n\",\n      \"6.9039921609184 [1.         0.27452515] [1.         0.23706619] [1.        0.1808617]\\n\",\n      \"6.933856541495104 [1.26836311 0.17292964] [1.23943197 0.13529919] [1.18344933 0.1190079 ]\\n\",\n      \"9.1754855532081 [1.32012519 0.10581155] [1.26743775 0.08459798] [1.21211291 0.05402826]\\n\",\n      \"7.95999259855081 [1.        0.2280663] [1.         0.19946997] [1.         0.14746733]\\n\",\n      \"12.90117192118057 [1.07258321 0.27879639] [1.03918096 0.25109079] [1.00187957 0.20184881]\\n\",\n      \"8.498572882655235 [1.         0.48467719] [1.0507401  0.43429832] [1.         0.40156388]\\n\",\n      \"4.968862936015541 [1.27935359 0.06278018] [1.25493564 0.06435664] [1.23059556 0.05567309]\\n\",\n      \"7.837028390588623 [1.         0.40538159] [1.         0.35924818] [1.         0.28033263]\\n\",\n      \"8.370309345977399 [1.         0.33104577] [1.         0.29727104] [1.         0.24940342]\\n\",\n      \"5.786796816872193 [2.78718743 0.43468676] [2.15784544 0.4521613 ] [1.49092268 0.45109614]\\n\",\n      \"14.888725983674073 [1.         0.47870681] [1.         0.42638215] [1.         0.38860656]\\n\",\n      \"12.042469117903966 [1.         0.34151152] [1.         0.29692816] [1.         0.22580898]\\n\",\n      \"11.371645567619755 [1.         0.49242957] [1.         0.44780231] [1.         0.37439217]\\n\",\n      \"6.448794133669505 [1.         0.30517096] [1.         0.28617864] [1.         0.25176324]\\n\",\n      \"5.524627378888152 [1.         0.30355241] [1.         0.28945991] [1.         0.25898903]\\n\",\n      \"11.51569566853841 [1.         0.33297855] [1.         0.30788312] [1.         0.26779871]\\n\",\n      \"5.773443165884961 [1.         0.32568689] [1.         0.30010188] [1.         
0.25709086]\\n\",\n      \"6.748136657813239 [1.01374175 0.25088664] [1.04548775 0.18812538] [1.03199114 0.13496343]\\n\",\n      \"12.052698614924967 [1.39627125 0.35087073] [1.23224636 0.32969554] [1.08285179 0.28375729]\\n\",\n      \"10.060533011831275 [1.         0.34598965] [1.         0.31133551] [1.         0.26383932]\\n\",\n      \"9.199727336978794 [1.         0.33325842] [1.         0.30006165] [1.         0.25371227]\\n\",\n      \"6.709969394594997 [1.10918757 0.06658346] [1.06499626 0.06553977] [1.02883924 0.05662035]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"9.6060888160924 [1.         0.12829732] [1.         0.10727965] [1.         0.07368046]\\n\",\n      \"7.822266741030743 [2.0912904  0.37328705] [1.66836271 0.37461134] [1.36645524 0.35356312]\\n\",\n      \"10.359787060458036 [1.         0.37363668] [1.         0.34732112] [1.        0.2977534]\\n\",\n      \"11.787197668597528 [1.         0.37676134] [1.         0.35483447] [1.         0.31241321]\\n\",\n      \"6.347286602421171 [1.31041515 0.05925655] [1.29216435 0.06166785] [1.26878466 0.05273866]\\n\",\n      \"3.266710456261983 [1.         0.16265129] [1.         0.15089868] [1.         0.12700483]\\n\",\n      \"9.17760253205858 [1.65796412 0.02662391] [1.15413931 0.19511157] [1.         0.31638017]\\n\",\n      \"5.901836666555211 [1.50092084 0.37688789] [1.27476517 0.3569254 ] [1.12850671 0.30477134]\\n\",\n      \"16.567119942126844 [1.28121459 0.4361741 ] [1.33713373 0.3694556 ] [1.48853805 0.25064159]\\n\",\n      \"14.200468274379704 [1.46160727 0.21797348] [1.26487462 0.21894439] [1.29068252 0.09228189]\\n\",\n      \"9.558914669941702 [1.17834556 0.11091971] [1.11168435 0.13118322] [1.10126591 0.10788446]\\n\",\n      \"8.630999472645442 [1.10455925 0.25939936] [1.07791108 0.25818788] [1.01876008 0.25475522]\\n\",\n      \"7.1896574173042405 [1.        
0.4394725] [1.69151733 0.15936391] [1.40745523 0.14090692]\\n\",\n      \"9.028997222445492 [1.35449893 0.11153744] [1.3897975  0.08635315] [1.23613002 0.11002226]\\n\",\n      \"9.540562872301349 [1.3576914  0.17826378] [1.3231857  0.16968261] [1.27488166 0.14102864]\\n\",\n      \"15.217502969341245 [1.         0.46800261] [1.03400325 0.39430012] [1.         0.36108749]\\n\",\n      \"8.149472820781698 [1.46439189 0.15456699] [1.41640802 0.15058807] [1.35129618 0.12745409]\\n\",\n      \"6.976944992161395 [1.         0.41820539] [1.63197955 0.20528542] [1.30790648 0.18357009]\\n\",\n      \"8.062994200864372 [1.2590772  0.18548839] [1.23136575 0.18143837] [1.17911986 0.17282262]\\n\",\n      \"13.844743802537836 [1.2015369  0.15183412] [1.19427918 0.13726769] [1.18246524 0.11350421]\\n\",\n      \"14.400292721241726 [1.43332449 0.2235953 ] [1.29722739 0.20826826] [1.32205969 0.08825859]\\n\",\n      \"9.004881522367741 [2.32464637 0.28278733] [1.89162317 0.28131392] [1.59227461 0.24001558]\\n\",\n      \"6.107589824184695 [1.39960932 0.31553911] [1.37343198 0.27532731] [1.27943831 0.24714296]\\n\",\n      \"13.48237112635666 [1.         0.36446671] [1.         0.33977545] [1.         0.29601266]\\n\",\n      \"3.121268510095921 [1.         0.18391253] [1.         0.17488635] [1.         0.14762028]\\n\",\n      \"11.978946731228046 [1.86766493 0.38186767] [1.51454431 0.39137673] [1.27668655 0.37908934]\\n\",\n      \"6.865283541115352 [1.         0.15293961] [1.         0.12481174] [1.        0.1027406]\\n\",\n      \"7.281891843928655 [1.69451487 0.01952414] [1.23764544 0.15298789] [1.         0.27897591]\\n\",\n      \"5.298772915637459 [1.         0.24827961] [1.         0.23037025] [1.        0.1959408]\\n\",\n      \"3.8077487270183905 [1.16329354 0.24901332] [1.0779487 0.2385194] [1.00363016 0.21778446]\\n\",\n      \"10.497606690401419 [1.         0.34818521] [1.         0.31596487] [1.         
0.27036681]\\n\",\n      \"7.158392552942906 [1.01385627 0.38510843] [1.04847678 0.33300206] [1.        0.3031797]\\n\",\n      \"10.846809611522731 [1.         0.38793241] [1.         0.34907699] [1.         0.28209224]\\n\",\n      \"7.490903894678285 [1.         0.26030496] [1.08122001 0.17188334] [1.05893602 0.11828488]\\n\",\n      \"9.852946364225987 [1.42228082 0.19200603] [1.36081795 0.18114805] [1.32457667 0.13506045]\\n\",\n      \"9.813044943909494 [1.         0.13507287] [1.        0.1211452] [1.06985331 0.03902373]\\n\",\n      \"13.34587581319962 [1.         0.30546792] [1.         0.28546389] [1.         0.24660574]\\n\",\n      \"7.418143411786704 [2.30965378 0.27421756] [2.36527214 0.22601405] [1.83067958 0.20473945]\\n\",\n      \"6.487563025349676 [1.         0.27149387] [1.         0.25961742] [1.         0.23347115]\\n\",\n      \"6.222757756016874 [1.        0.3155778] [1.        0.2964766] [1.         0.26161223]\\n\",\n      \"11.059078872787278 [1.79205417 0.45790815] [1.50336918 0.43939687] [1.15875599 0.42807139]\\n\",\n      \"12.627232736967574 [2.43846299 0.48689282] [1.8662622  0.49410938] [1.4464353  0.47503725]\\n\",\n      \"8.497852968895268 [1.         0.33173961] [1.         0.29790762] [1.         0.24960098]\\n\",\n      \"4.773945892232486 [1.38559259 0.37218966] [1.14653168 0.3521637 ] [1.05064959 0.28242611]\\n\",\n      \"4.9798440733126625 [1.2818747  0.06413411] [1.26715682 0.06259901] [1.25116844 0.05028596]\\n\",\n      \"7.215982714193697 [1.22008475 0.43899845] [1.1838212  0.40659116] [1.03220971 0.39996869]\\n\",\n      \"12.222665510347309 [1.         0.35190148] [1.         0.30240314] [1.         
0.23161851]\\n\",\n      \"14.546352314783011 [1.01211239 0.47213658] [1.02836638 0.42257262] [1.04885107 0.38179193]\\n\",\n      \"6.232905485930497 [1.14838793 0.21144955] [1.14887286 0.174099  ] [1.14116221 0.10557713]\\n\",\n      \"6.2462598538040845 [1.28307071 0.11209189] [1.19406696 0.09839082] [1.20154404 0.03788191]\\n\",\n      \"12.049837214204034 [1.17801886 0.26021165] [1.10615307 0.23640304] [1.04174162 0.19228388]\\n\",\n      \"4.35580424954309 [1.         0.26060752] [ 1.43414028 -0.00861912] [1.24491131 0.00611205]\\n\",\n      \"8.471650763073383 [1.56754364 0.2797233 ] [1.33000344 0.27946502] [1.59653944 0.12456281]\\n\",\n      \"7.49044012651744 [1.32171485 0.39542063] [1.26162582 0.34219468] [1.22760822 0.2795295 ]\\n\",\n      \"10.786327365502226 [1.         0.42439941] [1.         0.38285526] [1.        0.3303409]\\n\",\n      \"10.373536489404659 [1.         0.25571275] [1.02855063 0.19742803] [1.05409558 0.14161136]\\n\",\n      \"21.022477749453145 [1.11983787 0.33305407] [1.02719116 0.32128849] [1.19594603 0.17525151]\\n\",\n      \"5.7531032804037325 [1.11015019 0.0588051 ] [1.06545661 0.05506   ] [1.02919721 0.04600535]\\n\",\n      \"18.915284201266825 [1.12302054 0.32809132] [1.01549009 0.32142077] [1.17509954 0.17326784]\\n\",\n      \"10.120576070558265 [1.         0.43104948] [1.         0.38759719] [1.         0.33643654]\\n\",\n      \"10.366369407611174 [1.2204956  0.43281109] [1.17844132 0.37696907] [1.15221851 0.31668571]\\n\",\n      \"8.920615707023098 [1.215459   0.37204049] [1.13545431 0.34796987] [1.         0.32275485]\\n\",\n      \"10.309453079836256 [1.         0.28581225] [1.         0.23793582] [1.         0.18713733]\\n\",\n      \"9.98892516276331 [1.         0.35439137] [1.         0.33185869] [1.         0.29803357]\\n\",\n      \"6.165809009113636 [1.         
0.24603162] [ 1.43675457 -0.00240027] [1.24853703 0.00712537]\\n\",\n      \"8.393433910101315 [1.42325199 0.30029363] [1.27237161 0.28883931] [1.47440567 0.14601451]\\n\",\n      \"12.787309279590836 [1.1576845  0.25694987] [1.10474044 0.23022065] [1.05309267 0.18348645]\\n\",\n      \"7.2769603440583035 [1.         0.23596967] [1.02008553 0.18715022] [1.05534694 0.10902808]\\n\",\n      \"7.108041328152496 [1.35727395 0.10880889] [1.28569031 0.10061569] [1.30496199 0.02775571]\\n\",\n      \"13.216900123439 [1.         0.34101936] [1.         0.29523019] [1.        0.2311334]\\n\",\n      \"14.692874439645578 [1.         0.48535927] [1.         0.43252866] [1.03359864 0.38608543]\\n\",\n      \"4.304520284885788 [1.31688245 0.05907244] [1.29540368 0.0611252 ] [1.27733039 0.04833066]\\n\",\n      \"6.330490988906221 [1.04545909 0.46497502] [1.01102258 0.43032254] [1.         0.38169003]\\n\",\n      \"9.685975035829152 [1.         0.33264391] [1.        0.2974268] [1.         0.24907809]\\n\",\n      \"15.702773916235305 [1.  0.5] [1.02657099 0.44381973] [1.10001176 0.35810512]\\n\",\n      \"7.853615334358044 [1.         0.28074219] [1.         0.26840914] [1.         0.24217976]\\n\",\n      \"5.975634138326699 [1.         0.31659836] [1.         0.29810188] [1.         0.26348462]\\n\",\n      \"9.916002163802123 [1.59868211 0.38111511] [1.52662322 0.33446563] [1.36835491 0.29793856]\\n\",\n      \"12.409547158848543 [1.         0.31865139] [1.         0.29673575] [1.         0.25524303]\\n\",\n      \"12.575131449327872 [1.54827206 0.1353389 ] [1.45397267 0.13495303] [1.412522   0.09243483]\\n\",\n      \"10.047461271617449 [1.         0.13838947] [1.         
0.12543906] [1.07980201 0.03699989]\\n\",\n      \"11.168585162085147 [1.40165996 0.36333975] [1.22556733 0.34348119] [1.03948977 0.30648717]\\n\",\n      \"6.3148142448845785 [1.05192976 0.21690696] [1.04212699 0.17886409] [1.00050678 0.14293561]\\n\",\n      \"4.213174439921227 [1.03662484 0.28966567] [1.00215127 0.26474702] [1.         0.21927834]\\n\",\n      \"9.350377743405199 [1.         0.35654214] [1.         0.32229306] [1.         0.27400208]\\n\",\n      \"7.092296103190132 [1.03188979 0.36834402] [1.05524235 0.32111044] [1.         0.29487285]\\n\",\n      \"5.155585443443663 [1.         0.26704342] [1.         0.24899008] [1.         0.21170052]\\n\",\n      \"8.949140629837553 [1.05108868 0.39110188] [1.         0.36651523] [1.         0.30836115]\\n\",\n      \"8.716553166271584 [1.         0.14921731] [1.        0.1269959] [1.        0.1051833]\\n\",\n      \"10.086518745680003 [1.69681287 0.04678019] [1.12508498 0.24327165] [1.         0.33185283]\\n\",\n      \"11.182964161467057 [1.63311972 0.42401394] [1.38102156 0.42183174] [1.22833146 0.39035724]\\n\",\n      \"13.76180563173197 [1.        0.3762183] [1.         0.34865231] [1.         0.30293159]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"4.590938221163645 [1.14637812 0.3602448 ] [1.10077618 0.32862303] [1.         0.30323663]\\n\",\n      \"10.918743442377107 [1.12710391 0.15037918] [1.10657724 0.14239141] [1.10051106 0.11869385]\\n\",\n      \"11.804579155662669 [1.33688069 0.43758953] [1.34865482 0.37730493] [1.3432065  0.28776081]\\n\",\n      \"13.177277116508062 [1.52878098 0.20696249] [1.28077    0.21824131] [1.30689721 0.08714808]\\n\",\n      \"6.787660550087845 [1.         0.39119141] [1.        0.3606803] [1.         
0.29720819]\\n\",\n      \"4.451091387013824 [1.19332638 0.17125793] [1.16410742 0.16994217] [1.11273918 0.16180904]\\n\",\n      \"8.513095756903981 [1.04369299 0.27813237] [1.04389932 0.26202727] [1.03733742 0.22642191]\\n\",\n      \"7.386611194268827 [1.93391501 0.3312007 ] [2.002023   0.25603203] [1.590588   0.23676108]\\n\",\n      \"15.14940803546385 [1.         0.47213044] [1.11423475 0.38502384] [1.         0.36287811]\\n\",\n      \"16.421871987502197 [1.         0.37642323] [1.        0.3443427] [1.02078065 0.2809473 ]\\n\",\n      \"7.161802208489492 [1.14310057 0.22848877] [1.13348165 0.21734642] [1.11509131 0.18413408]\\n\",\n      \"7.6595161103083855 [1.         0.14029402] [1.         0.09303844] [1.         0.08251386]\\n\",\n      \"13.35693014887036 [1.48071015 0.19590308] [1.29772968 0.19501838] [1.30705279 0.08045421]\\n\",\n      \"10.803408667325941 [1.98737793 0.33374324] [1.75384086 0.30452603] [1.53861443 0.25186221]\\n\",\n      \"10.075008722465236 [1.08200726 0.13105404] [1.         0.16296988] [1.         0.13251569]\\n\",\n      \"9.195155565321723 [1.03749115 0.27621976] [1.02056792 0.27036357] [1.         0.25340891]\\n\",\n      \"7.43204157510038 [1.         0.43050409] [1.67734648 0.21263812] [1.35465377 0.19194443]\\n\",\n      \"5.896029385029159 [1.42870304 0.27649214] [1.2812687  0.26793401] [1.18041262 0.24718245]\\n\",\n      \"8.123436416015343 [1.        0.3134811] [1.        0.2695824] [1.         0.19264644]\\n\",\n      \"5.716475809488119 [1.50767532 0.08146423] [1.38561093 0.08507805] [1.26967229 0.0767172 ]\\n\",\n      \"9.033193837283932 [2.98305142 0.24181145] [2.39644536 0.25587678] [1.79498076 0.26201245]\\n\",\n      \"12.164324067161884 [1.        0.3947932] [1.         0.36898214] [1.         0.32283695]\\n\",\n      \"5.339787973130951 [1.         0.24729247] [1.         0.22994155] [1.         0.19348346]\\n\",\n      \"12.678490499063159 [1.64299616 0.04190122] [1.0591637  0.24535916] [1.         
0.34403424]\\n\",\n      \"6.077708548856496 [1.14963091 0.17937724] [1.13040526 0.15018877] [1.08640538 0.11397255]\\n\",\n      \"12.178252021323798 [1.28310002 0.38013975] [1.11254179 0.36015164] [1.         0.30814744]\\n\",\n      \"7.309460080103556 [1.06520598 0.37547398] [1.11286913 0.30832585] [1.00431014 0.30221934]\\n\",\n      \"9.220776691475416 [1.         0.34339892] [1.         0.31220978] [1.         0.27417617]\\n\",\n      \"5.098050429640151 [1.         0.29674274] [1.         0.26830966] [1.         0.22684651]\\n\",\n      \"9.847574352492146 [1.         0.13295115] [1.         0.11941661] [1.07107621 0.03736894]\\n\",\n      \"9.378924425161026 [1.66026543 0.37723044] [1.60835751 0.32637675] [1.36573792 0.29443589]\\n\",\n      \"13.182248170063556 [1.         0.31282152] [1.         0.29206134] [1.         0.24996682]\\n\",\n      \"5.318000365818506 [1.        0.3143855] [1.         0.29229958] [1.         0.25367857]\\n\",\n      \"6.6286177536373945 [1.54499351 0.40292311] [1.70761229 0.35502976] [1.5882886  0.32408944]\\n\",\n      \"5.441334460519029 [1.         0.23373452] [1.         0.19965821] [1.         0.15861831]\\n\",\n      \"6.733051962029401 [1.13882056 0.115319  ] [1.12429207 0.11785498] [1.10701285 0.11406595]\\n\",\n      \"4.808337374478711 [1.47909169 0.35586786] [1.19549607 0.33952498] [1.05558185 0.28169271]\\n\",\n      \"7.18045223990615 [1.         0.35109998] [1.         0.31738902] [1.        0.2704406]\\n\",\n      \"9.139950160459536 [2.73194726 0.41418685] [2.11928726 0.42496597] [1.67686128 0.40544779]\\n\",\n      \"16.14149869210153 [1.         0.47548209] [1.04629857 0.42856518] [1.06596913 0.37772377]\\n\",\n      \"12.070641071298047 [1.         0.34527724] [1.         0.29705856] [1.        0.2276855]\\n\",\n      \"7.434200544866427 [1.32181479 0.11508353] [1.2032701  0.10868013] [1.24194895 0.03617664]\\n\",\n      \"6.870374797147275 [1.         0.25905316] [1.         0.22340423] [1.         
0.15733055]\\n\",\n      \"11.709246462998976 [1.16809094 0.26245428] [1.10184073 0.23797386] [1.05045989 0.19024058]\\n\",\n      \"8.20325483957443 [1.         0.33986977] [1.         0.31646446] [1.         0.28104671]\\n\",\n      \"9.033009859575218 [1.17359599 0.37475019] [1.04596266 0.35637363] [1.         0.30693566]\\n\",\n      \"15.702183543583855 [1.         0.37089796] [1.         0.33939589] [1.         0.30634559]\\n\",\n      \"15.540954207408367 [1.         0.40327596] [1.         0.37952947] [1.         0.33862027]\\n\",\n      \"21.38126183828421 [1.28070914 0.29927744] [1.14955948 0.29638129] [1.39224783 0.13822175]\\n\",\n      \"22.436770425127463 [1.22642345 0.31261728] [1.09821051 0.30894634] [1.29405853 0.15898035]\\n\",\n      \"6.792444514498642 [1.44001498 0.4332393 ] [1.34888428 0.40130308] [1.26794739 0.34149732]\\n\",\n      \"8.968912513117655 [1.42961773 0.3099394 ] [1.23306083 0.30323857] [1.49129172 0.14860136]\\n\",\n      \"5.674805610382074 [1.         0.24244748] [ 1.44426089 -0.00670134] [1.25205041 0.00637623]\\n\",\n      \"9.20696195437281 [1.         0.34680365] [1.         0.32356847] [1.         0.28885825]\\n\",\n      \"12.01006668627144 [1.20058944 0.45253004] [1.23001275 0.3922238 ] [1.20167105 0.32995753]\\n\",\n      \"7.268419832489814 [1.22416277 0.33895806] [1.17951589 0.30673209] [1.         0.29074284]\\n\",\n      \"12.0522052503042 [1.11756825 0.26540019] [1.07952285 0.2386899 ] [1.03653212 0.19095129]\\n\",\n      \"10.475280985058973 [1.36359868 0.11570846] [1.30433452 0.10568103] [1.28807227 0.05140718]\\n\",\n      \"7.492840276860108 [1.02794703 0.22600696] [1.04697332 0.17881622] [1.09720094 0.09238529]\\n\",\n      \"13.681650161671499 [1.11399402 0.47833617] [1.18537375 0.41733039] [1.122122   0.37450852]\\n\",\n      \"13.424825944230744 [1.         0.34504583] [1.         0.29594668] [1.         
0.22853987]\\n\",\n      \"4.780767547342776 [1.46095773 0.35687878] [1.19687247 0.33216925] [1.07251754 0.26640343]\\n\",\n      \"7.911775322280817 [2.63443407 0.43824381] [2.05924642 0.45010572] [1.66740395 0.42959416]\\n\",\n      \"4.015854860532171 [1.         0.20793752] [1.         0.17329306] [1.03386235 0.11967121]\\n\",\n      \"6.713520362470293 [1.         0.32240739] [1.         0.30235687] [1.         0.26612081]\\n\",\n      \"12.074188326332035 [1.         0.30551639] [1.      0.28382] [1.         0.24327509]\\n\",\n      \"9.454394719378218 [1.75232135 0.36525822] [1.73434966 0.30734527] [1.47637131 0.2781888 ]\\n\",\n      \"10.293786594978206 [1.         0.13087376] [1.         0.11884709] [1.0620732  0.04798402]\\n\",\n      \"7.5419608178326385 [1.05831586 0.37541851] [1.1376635  0.30449258] [1.01610063 0.29506392]\\n\",\n      \"6.520678105114464 [1.         0.36343357] [1.         0.33105139] [1.         0.28624847]\\n\",\n      \"3.978041257149275 [1.0435017  0.29055885] [1.00155223 0.26582389] [1.         0.22054885]\\n\",\n      \"6.415524536488719 [1.14623474 0.17867283] [1.12664102 0.14660054] [1.07920918 0.11174827]\\n\",\n      \"9.048678022567225 [1.1420504  0.40753175] [1.         0.39215311] [1.         0.31516691]\\n\",\n      \"12.858382628127798 [1.55014662 0.04664071] [1.        0.2662912] [1.         0.34158427]\\n\",\n      \"12.04659315276056 [1.         0.37716955] [1.         0.35372715] [1.         0.30968543]\\n\",\n      \"5.816629031044932 [1.51156517 0.08032432] [1.38840341 0.08455815] [1.2701874  0.07628197]\\n\",\n      \"5.4013041206763015 [1.4884411  0.25501198] [1.40327348 0.23347025] [1.29894998 0.21259459]\\n\",\n      \"7.914935200068231 [1.         0.34859634] [1.         0.29710985] [1.         0.20886223]\\n\",\n      \"8.947884472998847 [1.03145903 0.28212011] [1.01350009 0.27661156] [1.         0.25647537]\\n\",\n      \"7.219420391296332 [1.         
0.44153257] [1.65049284 0.22948439] [1.37788939 0.19071623]\\n\",\n      \"15.518736994613539 [1.26404283 0.44324559] [1.34474669 0.37493212] [1.43185122 0.26851484]\\n\",\n      \"12.788760154664589 [1.42655951 0.21897641] [1.26131064 0.21251736] [1.29509424 0.0944886 ]\\n\",\n      \"6.177922115924289 [1.         0.13960248] [1.         0.12342618] [1.         0.10192676]\\n\",\n      \"8.04152307007944 [1.         0.14275038] [1.         0.09092113] [1.         0.08070935]\\n\",\n      \"7.215196988886108 [1.         0.24482581] [1.         0.20797928] [1.         0.15821295]\\n\",\n      \"8.09476838905765 [1.10566399 0.26031656] [1.09516876 0.24780982] [1.07277844 0.21927915]\\n\",\n      \"13.639815646703246 [1.38631652 0.34327928] [1.38682424 0.29590406] [1.2800876  0.23538911]\\n\",\n      \"14.468988791238358 [1.         0.44749568] [1.00451627 0.37178933] [1.         0.33439722]\\n\",\n      \"6.988720319213014 [1.         0.12226574] [1.         0.10294753] [1.         0.08100132]\\n\",\n      \"15.773147620581462 [1.29654139 0.37541121] [1.28474104 0.31503998] [1.33065061 0.20304062]\\n\",\n      \"10.751327276204009 [1.23926839 0.21041204] [1.22484533 0.1981505 ] [1.19218005 0.17204274]\\n\",\n      \"5.684842553233122 [1.03737285 0.12962043] [1.01477472 0.12172435] [1.00400735 0.099963  ]\\n\",\n      \"5.417121817818472 [1.         0.27240543] [1.        0.2325963] [1.         0.17525595]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"5.891694302225043 [1.25972777 0.16373711] [1.22746573 0.15912013] [1.17665212 0.14910726]\\n\",\n      \"8.793160183816639 [1.         0.39273616] [1.         0.34320608] [1.         0.27509644]\\n\",\n      \"5.716706882708992 [2.13114339 0.31695666] [1.91464865 0.2893754 ] [1.66717715 0.23251986]\\n\",\n      \"9.396427221145542 [1.         0.44386907] [1.         0.43682454] [1.         
0.41753878]\\n\",\n      \"4.060973086245111 [1.12811625 0.07502233] [1.09422964 0.05856901] [1.06109923 0.04190814]\\n\",\n      \"4.4374517881175155 [1.07257936 0.30326746] [1.10072261 0.22900328] [1.12278412 0.12179577]\\n\",\n      \"11.133661078559285 [1.         0.46226972] [1.         0.41594019] [1.         0.34609965]\\n\",\n      \"6.150874953906226 [1.33361544 0.11011473] [1.24339714 0.10663268] [1.13506818 0.09751161]\\n\",\n      \"4.653826185082733 [1.         0.12597343] [1.         0.09853081] [1.        0.0768966]\\n\",\n      \"7.2510745336822495 [1.53460837 0.02320472] [1.13416   0.1553736] [1.         0.25563807]\\n\",\n      \"4.356431560576406 [1.29197913 0.06236888] [1.27391104 0.06281774] [1.25477735 0.05111601]\\n\",\n      \"7.112568932618397 [1.        0.3945393] [1.         0.35591552] [1.         0.30137758]\\n\",\n      \"18.164439475148683 [1.21523226 0.31266801] [1.180494   0.27199212] [1.05022813 0.23082529]\\n\",\n      \"7.742751882633314 [1.15416499 0.06784116] [1.10306645 0.06395505] [1.06683635 0.04797953]\\n\",\n      \"6.350897941419369 [1.0541251 0.2380309] [1.15428629 0.14187908] [1.11715909 0.09057371]\\n\",\n      \"14.767729380228179 [1.         0.37644373] [1.         0.33488604] [1.         0.26979247]\\n\",\n      \"6.385047435437507 [1.04806365 0.38712485] [1.         0.35686888] [1.         0.29697577]\\n\",\n      \"9.692609050090336 [1.40392011 0.20440658] [1.34051411 0.19120975] [1.28410855 0.1465873 ]\\n\",\n      \"7.024719751551154 [1.         0.32698809] [1.         0.30520781] [1.       0.267538]\\n\",\n      \"13.663091185062918 [1.         0.45325591] [1.        0.3664679] [1.05514643 0.33619966]\\n\",\n      \"11.107894150487391 [1.41292528 0.32187687] [1.170399   0.30418462] [1.         0.26528342]\\n\",\n      \"14.766242092442566 [1.         
0.47659992] [1.26344653 0.37950707] [1.20590299 0.32832604]\\n\",\n      \"6.4131843635100925 [1.06094015 0.25942477] [1.08727164 0.23325947] [1.06346143 0.21081528]\\n\",\n      \"5.080159014798777 [1.         0.18962443] [1.        0.1509963] [1.         0.05423987]\\n\",\n      \"5.719033079084207 [2.02944065 0.33118679] [1.64265725 0.32581619] [1.3396576  0.30601925]\\n\",\n      \"4.591219990471863 [1.31803798 0.07871335] [1.26773763 0.0718669 ] [1.3020643  0.02877215]\\n\",\n      \"3.7862175374727203 [1.         0.23186105] [1.         0.19123061] [1.         0.14931234]\\n\",\n      \"8.039122145306324 [1.56302739 0.31515415] [1.25831753 0.29635916] [1.36207974 0.15380288]\\n\",\n      \"9.677334364040973 [1.18170124 0.31586464] [1.13851834 0.26817432] [1.04093963 0.21807475]\\n\",\n      \"2.7370408421682444 [1.13616131 0.2318686 ] [1.40806525 0.0115724 ] [1.22428605 0.02977322]\\n\",\n      \"9.29730192539997 [1.30774979 0.11942896] [1.16347956 0.14512899] [1.3173694  0.01187218]\\n\",\n      \"15.526524691289387 [1.        0.3310932] [1.        0.2844831] [1.         0.23075083]\\n\",\n      \"11.017672565326 [1.18686802 0.24602856] [1.14106091 0.21916931] [1.09163688 0.17232029]\\n\",\n      \"10.99085372091498 [1.         0.23665997] [1.         0.21189778] [1.         0.17769479]\\n\",\n      \"9.463678038950732 [1.4321039  0.23182994] [1.3824293  0.18283563] [1.17269764 0.15142691]\\n\",\n      \"8.244724159079153 [1.29948866 0.45833219] [1.24525979 0.41506679] [1.17430125 0.35533612]\\n\",\n      \"10.76970808108283 [1.        0.2926715] [1.         0.24471049] [1.       0.196557]\\n\",\n      \"8.453974950451215 [1.2738775  0.41625414] [1.12890283 0.39542715] [1.02955501 0.35860726]\\n\",\n      \"7.0707075590839095 [1.         0.25485953] [1.         0.22364087] [1.        0.1733184]\\n\",\n      \"6.593210405970022 [1.         0.28848014] [1.         0.25040597] [1.        0.1918077]\\n\",\n      \"9.602109747672733 [1.         0.28655783] [1.    
    0.2396139] [1.         0.18953113]\\n\",\n      \"9.903267217800787 [1.         0.41750385] [1.         0.38717919] [1.         0.33202579]\\n\",\n      \"8.456703731077353 [1.31691898 0.44820206] [1.2646766  0.40121542] [1.20800165 0.33430126]\\n\",\n      \"11.442953078440361 [1.04063187 0.21002153] [1.05936275 0.18194225] [1.04522708 0.15555158]\\n\",\n      \"12.677632319692128 [1.1498834 0.2550538] [1.0994028  0.22739616] [1.05755804 0.17862199]\\n\",\n      \"11.431287810767149 [1.42107401 0.22256108] [1.36944357 0.17256377] [1.17493751 0.1400153 ]\\n\",\n      \"15.47924038477749 [1.         0.33688982] [1.         0.29763816] [1.         0.25367715]\\n\",\n      \"5.809853441619306 [1.5210456  0.11331387] [1.57340061 0.07305487] [1.37659686 0.05108007]\\n\",\n      \"8.02832602785699 [1.38440827 0.11370136] [1.31684005 0.10414527] [1.32822502 0.03368428]\\n\",\n      \"2.1627993504342338 [1.23271853 0.19819153] [ 1.50047342 -0.00571418] [1.2730289  0.01018472]\\n\",\n      \"5.784735548158705 [1.23161683 0.09958643] [1.21975339 0.08252126] [1.36840433 0.01555117]\\n\",\n      \"6.867159456154496 [1.09601151 0.39852261] [1.08863841 0.33842392] [1.19910545 0.20261333]\\n\",\n      \"3.479654525635185 [1.         0.23182781] [1.         0.18850923] [1.        0.1469227]\\n\",\n      \"4.640527830169732 [1.11161203 0.17408629] [1.1134837  0.13638527] [1.13256368 0.05422848]\\n\",\n      \"5.743096677976866 [2.4811614 0.2889833] [2.00574636 0.29374786] [1.5641532  0.27851168]\\n\",\n      \"12.946199421719438 [1.         0.28075043] [1.         0.26039756] [1.         0.22096662]\\n\",\n      \"6.074956281392489 [1.02888253 0.26223522] [1.04211541 0.23956685] [1.01987354 0.21830608]\\n\",\n      \"13.434984280530074 [1.         0.45607302] [1.0592372  0.38333056] [1.         0.34673542]\\n\",\n      \"6.8208066924477215 [1.         0.31946752] [1.         0.29469548] [1.         
0.25358076]\\n\",\n      \"9.793612636024221 [1.31029557 0.23113843] [1.25568947 0.21249858] [1.19808227 0.16464676]\\n\",\n      \"9.494193456839135 [1.         0.42120974] [1.         0.37293915] [1.         0.29732569]\\n\",\n      \"6.13118477880261 [1.01103742 0.2534833 ] [1.13576454 0.14824446] [1.10924289 0.09856326]\\n\",\n      \"15.819647246381082 [1.         0.38783356] [1.         0.33566908] [1.         0.26309627]\\n\",\n      \"9.600605129090999 [1.         0.41556355] [1.         0.35365931] [1.        0.2601237]\\n\",\n      \"6.3574879048298545 [1.         0.38990792] [1.         0.35223975] [1.        0.2970689]\\n\",\n      \"6.90601017156322 [1.15069262 0.06453205] [1.10300312 0.05805457] [1.06194101 0.04461918]\\n\",\n      \"4.661849733982536 [1.29651397 0.06215004] [1.272838   0.06425507] [1.25147502 0.05305137]\\n\",\n      \"8.761250791481844 [1.         0.14595345] [1.         0.13013114] [1.         0.11264322]\\n\",\n      \"7.9036448804751664 [1.53889132 0.02408134] [1.14229798 0.16372788] [1.         0.26051279]\\n\",\n      \"12.162078747056947 [1.         0.39060393] [1.         0.35973595] [1.         0.31356092]\\n\",\n      \"5.451856216340836 [1.36560124 0.1024369 ] [1.27343169 0.10506909] [1.18994043 0.09010483]\\n\",\n      \"10.840442889289944 [1.        0.4680975] [1.         0.42356851] [1.        0.3540948]\\n\",\n      \"4.892808805866077 [1.         0.31285284] [1.08786949 0.22196679] [1.16825773 0.09104704]\\n\",\n      \"3.91410393618042 [1.13000409 0.07581248] [1.0835281  0.06373103] [1.05629125 0.0436979 ]\\n\",\n      \"6.69791580075448 [1.54250332 0.39868517] [1.32257885 0.38329475] [1.09516635 0.35167432]\\n\",\n      \"10.534278235064187 [1.         0.45227187] [1.         0.44406083] [1.         
0.41957123]\\n\",\n      \"4.808537337151308 [1.20321121 0.16529559] [1.17718295 0.16302051] [1.13213581 0.15200777]\\n\",\n      \"3.8779116434356427 [1.26273839 0.11194173] [1.29326685 0.06953592] [1.24823897 0.03262819]\\n\",\n      \"10.331828146365424 [1.         0.39101359] [1.13691177 0.29844184] [1.1598001  0.21833517]\\n\",\n      \"4.356549666351259 [1.03360225 0.41918878] [1.      0.40059] [1.         0.35745555]\\n\",\n      \"5.435808897380261 [1.         0.28500083] [1.         0.24256721] [1.00652154 0.18118521]\\n\",\n      \"4.530972308690726 [1.05142997 0.07730137] [1.01134688 0.09776861] [1.06021387 0.03916837]\\n\",\n      \"8.713184312626323 [1.38576613 0.18030878] [1.3542891 0.1728741] [1.30368522 0.14759379]\\n\",\n      \"4.982778419286587 [1.        0.1108381] [1.         0.09112781] [1.         0.06686228]\\n\",\n      \"13.732649395522882 [1.47967034 0.29660066] [1.42575629 0.25034483] [1.41245788 0.15706471]\\n\",\n      \"11.952034661936985 [1.19330595 0.21750693] [1.18816777 0.20535072] [1.17868428 0.17350153]\\n\",\n      \"12.868449341343661 [1.5981332  0.30274196] [1.53698373 0.25092222] [1.41320284 0.17702285]\\n\",\n      \"5.9065513403184 [1.         0.25653251] [1.         0.22001491] [1.         0.16433825]\\n\",\n      \"7.889022281908016 [1.         0.13622052] [1.         0.08505973] [1.         0.07799846]\\n\",\n      \"17.62149507636407 [1.  0.5] [1.        0.4868302] [1.         0.45386779]\\n\",\n      \"14.184563511527882 [1.67857751 0.3744754 ] [1.65420228 0.32485745] [1.60230479 0.25334728]\\n\",\n      \"8.23373358791664 [1.46165354 0.08280924] [1.42603347 0.0504182 ] [1.26485091 0.07282777]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"18.124118913101793 [1.32922321 0.23887569] [1.10528409 0.2573875 ] [1.22403664 0.10308437]\\n\",\n      \"3.9868874555079876 [1.05720778 0.40419764] [1.         0.38784377] [1.         
0.34279549]\\n\",\n      \"4.599539602865495 [1.14186506 0.14652774] [1.15570267 0.10849091] [1.15647165 0.0610628 ]\\n\",\n      \"5.759627160656901 [1.70349963 0.23830525] [1.61395007 0.21029727] [1.48694521 0.19041081]\\n\",\n      \"6.54632418115673 [1.1821057  0.27387597] [1.23419344 0.208279  ] [1.2300459  0.11255374]\\n\",\n      \"4.006927159625802 [1.12399294 0.07605867] [1.08566762 0.05949197] [1.048149   0.04550453]\\n\",\n      \"5.155593872109015 [1.3669911  0.10548395] [1.2831182  0.10501239] [1.19686576 0.092191  ]\\n\",\n      \"12.297423091066193 [1.         0.38016573] [1.         0.35825509] [1.         0.31677196]\\n\",\n      \"11.112905763201406 [1.        0.4562093] [1.         0.40945472] [1.         0.33857818]\\n\",\n      \"4.780549932697073 [1.30110912 0.06118039] [1.27517121 0.06293304] [1.25995527 0.04937645]\\n\",\n      \"12.850462262114648 [1.        0.4017174] [1.         0.37097501] [1.         0.32171805]\\n\",\n      \"4.61288722765694 [1.         0.10093775] [1.         0.08630901] [1.         0.06161945]\\n\",\n      \"6.096782451635316 [1.54175764 0.32680338] [1.36718423 0.30584412] [1.17793683 0.27107083]\\n\",\n      \"5.131156121131035 [1.17682994 0.22998307] [1.05034279 0.23655693] [1.         0.21443303]\\n\",\n      \"7.879880527969907 [1.14493467 0.07164689] [1.11483468 0.06058582] [1.07292721 0.04559548]\\n\",\n      \"9.387342953990725 [1.        0.3696372] [1.09627522 0.27817551] [1.05165484 0.26282338]\\n\",\n      \"4.918880753691131 [1.45908293 0.29313014] [1.3782307  0.27507742] [1.23199078 0.24633128]\\n\",\n      \"14.158239224299484 [1.         0.36476705] [1.         0.32522647] [1.         0.26000801]\\n\",\n      \"9.517298273968473 [1.         0.42602333] [1.         0.37527014] [1.         0.29698766]\\n\",\n      \"8.231845283312623 [1.18861333 0.29332819] [1.18790864 0.24573307] [1.13856228 0.17827134]\\n\",\n      \"9.53983709142983 [1.13799697 0.38524838] [1.         0.36275799] [1.         
0.28309355]\\n\",\n      \"6.666709857548999 [1.         0.31742595] [1.         0.29604461] [1.         0.25653994]\\n\",\n      \"6.270934162557976 [1.07343917 0.25619321] [1.08890573 0.23226074] [1.06953595 0.20817377]\\n\",\n      \"14.344944542601246 [1.        0.4773223] [1.03664224 0.42563881] [1.0325325  0.37282668]\\n\",\n      \"13.232647434100098 [1.         0.28509013] [1.         0.26535221] [1.         0.22192399]\\n\",\n      \"5.415170282276734 [1.03703071 0.20910594] [1.03781338 0.17703365] [1.00296583 0.15074967]\\n\",\n      \"7.195833600123861 [1.         0.41679496] [1.         0.35882238] [1.01981068 0.26818582]\\n\",\n      \"5.835186775207567 [2.44922973 0.29453994] [1.95348276 0.30152249] [1.48030465 0.29045715]\\n\",\n      \"3.914966619336214 [1.         0.20027453] [1.         0.15885312] [1.         0.05890127]\\n\",\n      \"13.213401229836093 [1.         0.34283422] [1.         0.29034656] [1.         0.22714111]\\n\",\n      \"2.9317032881679026 [1.10880727 0.25087744] [1.41523591 0.01255218] [1.22426241 0.03113065]\\n\",\n      \"8.624558853353015 [1.18565813 0.2978048 ] [1.13924446 0.25165697] [1.04619305 0.20361098]\\n\",\n      \"14.12549007603379 [1.         0.33063251] [1.         0.28683951] [1.        0.2317352]\\n\",\n      \"3.4455625833028116 [1.73370339 0.05743556] [1.81161211 0.01492903] [1.28283606 0.05056141]\\n\",\n      \"9.25790379122618 [1.27904293 0.24981154] [1.24774003 0.19116945] [1.13032598 0.15074178]\\n\",\n      \"10.835337161227628 [1.         0.24161495] [1.         0.21326125] [1.         0.17767127]\\n\",\n      \"7.911179455599862 [1.16541014 0.44322421] [1.01999164 0.43198894] [1.         0.37928913]\\n\",\n      \"10.342808882158653 [1.05037901 0.24864016] [1.07215549 0.19487928] [1.07843212 0.14649988]\\n\",\n      \"8.056915768772097 [1.30619205 0.45765377] [1.24991115 0.415563  ] [1.2128384  0.34760788]\\n\",\n      \"14.85465376848101 [1.         0.40002233] [1.         0.37138823] [1.         
0.31310465]\\n\",\n      \"19.511684256662633 [1.26585356 0.31415552] [1.11639465 0.30895147] [1.21387797 0.1743843 ]\\n\",\n      \"19.905318196850637 [1.27049544 0.30873869] [1.12271798 0.30548697] [1.22200986 0.1718376 ]\\n\",\n      \"7.046293011243918 [1.        0.3732745] [1.         0.33097561] [1.         0.25958042]\\n\",\n      \"19.706152544196275 [1.         0.39955553] [1.         0.35989177] [1.         0.32481473]\\n\",\n      \"8.271463177472285 [1.43793467 0.42249578] [1.32703756 0.39426178] [1.31412245 0.32237713]\\n\",\n      \"7.218605713310288 [1.0415818  0.47150815] [1.         0.43114047] [1.         0.36648088]\\n\",\n      \"12.62778912841859 [1.         0.27095295] [1.01846325 0.21648754] [1.03289859 0.16241479]\\n\",\n      \"9.88270569624501 [1.2978252  0.24223272] [1.26559263 0.18620396] [1.11782403 0.15160372]\\n\",\n      \"1.5636127976922047 [1.         0.28248478] [1.         0.22363141] [1.         0.09104517]\\n\",\n      \"13.390352111801858 [1.         0.34062208] [1.         0.29697274] [1.         0.24358278]\\n\",\n      \"8.261674494393619 [1.25347048 0.27279468] [1.19417282 0.22676013] [1.11356917 0.17571076]\\n\",\n      \"13.234432716213522 [1.         0.34471039] [1.         0.29848327] [1.         0.23258223]\\n\",\n      \"3.169614729182636 [1.10947422 0.26262313] [1.47078038 0.01148745] [1.26961127 0.02460483]\\n\",\n      \"5.852995628206429 [2.48916533 0.28730657] [1.95308589 0.29949853] [1.47909348 0.28945001]\\n\",\n      \"4.131430687753246 [1.34952521 0.12119938] [1.33574652 0.09478586] [1.10795334 0.06341004]\\n\",\n      \"6.340055497747631 [1.         0.42067454] [1.00844278 0.35966451] [1.05388587 0.25557233]\\n\",\n      \"5.506406061399939 [1.05264599 0.21723827] [1.03226715 0.1870196 ] [1.00498817 0.15501192]\\n\",\n      \"4.89467564828471 [1.1540556  0.12693867] [1.15079474 0.1005251 ] [1.0664712  0.04490176]\\n\",\n      \"8.698451497948541 [1.         0.36944759] [1.         0.33819784] [1.         
0.29021945]\\n\",\n      \"5.8829784341171125 [1.00572393 0.28395658] [1.02614286 0.25901765] [1.01452655 0.23212322]\\n\",\n      \"12.977928439478653 [1.         0.28127933] [1.         0.26068327] [1.        0.2213442]\\n\",\n      \"14.755100213631243 [1.         0.46201402] [1.19924146 0.37012335] [1.11153423 0.33087121]\\n\",\n      \"10.064144269856458 [1.14832169 0.37738287] [1.         0.35893565] [1.         0.28248572]\\n\",\n      \"7.6368844832548515 [1.         0.31132055] [1.         0.28578546] [1.         0.24287838]\\n\",\n      \"6.687217298797576 [1.15724037 0.28676805] [1.17081934 0.23879038] [1.13494045 0.17344283]\\n\",\n      \"9.58202188153881 [1.         0.42426239] [1.         0.37442876] [1.         0.29812282]\\n\",\n      \"4.868615337808833 [1.15919897 0.2433856 ] [1.07631749 0.23833285] [1.         0.22388433]\\n\",\n      \"11.655515957723736 [1.         0.37358036] [1.         0.33007272] [1.         0.26131269]\\n\",\n      \"9.596588977151242 [1.         0.36608441] [1.08511938 0.27830716] [1.00992861 0.26914553]\\n\",\n      \"6.034714182004036 [1.58973187 0.31740757] [1.43175713 0.29578856] [1.23020381 0.26008328]\\n\",\n      \"12.564334293206327 [1.         0.39437194] [1.         0.36352051] [1.         0.31566445]\\n\",\n      \"5.889563098969405 [1.         0.14375823] [1.         0.11739547] [1.         0.09522836]\\n\",\n      \"4.570063428873963 [1.30444438 0.06058142] [1.27588618 0.06371982] [1.26074365 0.0505582 ]\\n\",\n      \"12.14829320880722 [1.         0.38080579] [1.         0.35939192] [1.         0.31904503]\\n\",\n      \"9.939093277977596 [1.         0.45869984] [1.         0.41444018] [1.         0.34632318]\\n\",\n      \"5.345084527136086 [1.36068286 0.09987179] [1.28112887 0.09920155] [1.19092264 0.08733046]\\n\",\n      \"3.834859218594147 [1.06751783 0.10027921] [1.07723675 0.0691018 ] [1.03444079 0.05330765]\\n\",\n      \"5.570271796656015 [1.         0.37750799] [1.         0.34380663] [1.         
0.28871752]\\n\",\n      \"6.632225582647623 [1.22723518 0.25349527] [1.20337442 0.22529021] [1.15206338 0.16473358]\\n\",\n      \"3.7245353268928985 [1.151156   0.38800718] [1.02595136 0.38470179] [1.         0.33839395]\\n\",\n      \"3.9682573428217274 [1.18939777 0.12761237] [1.21977333 0.08557312] [1.19246896 0.04557403]\\n\",\n      \"20.755626834886233 [1.  0.5] [1.         0.49153093] [1.         0.45250084]\\n\",\n      \"4.607417086387852 [1.04548425 0.1261651 ] [1.03005303 0.11349889] [1.00131411 0.10021228]\\n\",\n      \"16.453616663613925 [1.22269744 0.26910482] [1.03089546 0.28080068] [1.16097286 0.12425261]\\n\",\n      \"7.748004782649326 [1.40331675 0.11321791] [1.39996438 0.07066028] [1.30206376 0.06501894]\\n\",\n      \"6.385436131632363 [1.70249006 0.36445203] [1.52836194 0.33543283] [1.25467478 0.29441277]\\n\",\n      \"15.397874222942347 [1.4608715 0.4252542] [1.44220057 0.37527856] [1.36308945 0.30897187]\\n\",\n      \"6.043950811136515 [1.         0.11375008] [1.         0.08301911] [1.         0.06138666]\\n\",\n      \"7.443403802940812 [1.         0.25151869] [1.         0.21670456] [1.         0.16237954]\\n\",\n      \"13.17500086957239 [1.44958733 0.33520476] [1.39339071 0.29233331] [1.24992011 0.23805035]\\n\",\n      \"10.79201123984015 [1.32974063 0.16457546] [1.30392758 0.16093623] [1.27719686 0.13503117]\\n\",\n      \"5.358439952320776 [1.         0.28845737] [1.         0.24758458] [1.         0.18727148]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"4.839713107103963 [1.         0.12382349] [1.00983414 0.06859447] [1.0173457  0.06087906]\\n\",\n      \"13.783040869059848 [1.         0.43234353] [1.         0.40366689] [1.         
0.36655522]\\n\",\n      \"11.355133210366915 [1.17375473 0.23277929] [1.1698897  0.21643841] [1.14696259 0.18725392]\\n\",\n      \"13.629910266670032 [1.52758876 0.29331789] [1.47490035 0.24233049] [1.43440128 0.15101406]\\n\",\n      \"6.168596795428792 [1.0998412  0.29026187] [1.13416883 0.2196127 ] [1.05986944 0.14377691]\\n\",\n      \"4.288050436543725 [1.19741602 0.04366837] [1.1038339  0.05014255] [1.0634177  0.03838106]\\n\",\n      \"5.836821811911606 [1.44333632 0.39423676] [1.20655943 0.37783335] [1.07501377 0.33167891]\\n\",\n      \"5.526374437104748 [1.44935734 0.09082052] [1.42332044 0.05413665] [1.35768648 0.03412326]\\n\",\n      \"19.879032561148527 [1.00143127 0.5       ] [1.         0.49989319] [1.         0.46461947]\\n\",\n      \"5.753800656087086 [1.         0.19888411] [1.         0.17527174] [1.         0.14177691]\\n\",\n      \"7.538761930151618 [1.         0.36227097] [1.         0.33502637] [1.         0.29019684]\\n\",\n      \"4.377269401419613 [1.27585858 0.05518306] [1.25798571 0.05445539] [1.22536991 0.04757483]\\n\",\n      \"14.113951039917655 [1.         0.39644585] [1.         0.36892543] [1.        0.3224152]\\n\",\n      \"6.050573340274629 [1.37935154 0.09971442] [1.27961532 0.09755718] [1.18644881 0.08388369]\\n\",\n      \"10.24287426790233 [1.         0.48176737] [1.         0.43454273] [1.         0.36321288]\\n\",\n      \"9.854976871397522 [1.         0.43564793] [1.         0.37695982] [1.         0.29500269]\\n\",\n      \"9.209505123130837 [1.10139433 0.31352503] [1.13433612 0.27430368] [1.06656546 0.21958064]\\n\",\n      \"5.41147311417093 [1.30676009 0.34040911] [1.13732792 0.32128211] [1.         0.28302368]\\n\",\n      \"6.800057809730648 [1.         0.39462282] [1.         0.36103037] [1.         0.31067217]\\n\",\n      \"15.101050513955645 [1.        0.3653043] [1.        
0.3153946] [1.00096094 0.24186618]\\n\",\n      \"7.873072404582189 [1.14277229 0.07795869] [1.10102994 0.07166576] [1.07135518 0.05263471]\\n\",\n      \"14.457991370510385 [1.         0.40917791] [1.         0.38303446] [1.         0.34620609]\\n\",\n      \"4.842251137901636 [1.00934833 0.31111778] [1.         0.29347213] [1.         0.25553139]\\n\",\n      \"6.367801696513165 [1.         0.28679118] [1.         0.25467389] [1.         0.20656831]\\n\",\n      \"6.483138962526544 [1.13785268 0.26727746] [1.14264027 0.24596525] [1.08479983 0.2304965 ]\\n\",\n      \"5.579665831228779 [1.22849478 0.24625582] [1.16458639 0.2330797 ] [1.06542959 0.21789782]\\n\",\n      \"10.466927565218658 [1.         0.37320549] [1.         0.32605051] [1.         0.25532877]\\n\",\n      \"9.093745011148197 [1.         0.36086579] [1.         0.33108925] [1.         0.28392246]\\n\",\n      \"15.3284043988048 [1.         0.36927274] [1.         0.34175819] [1.34924593 0.1669367 ]\\n\",\n      \"3.0393841381302527 [1.25619506 0.18919224] [ 1.48827333e+00 -2.32351488e-04] [1.26456789 0.02115308]\\n\",\n      \"9.138319578529781 [1.36479105 0.24590478] [1.28245287 0.20958615] [1.14847979 0.16865384]\\n\",\n      \"6.488615015580878 [1.12768567 0.12468495] [1.14242763 0.09519091] [1.16081641 0.03404052]\\n\",\n      \"3.689592393714326 [1.         0.23685664] [1.         0.19886097] [1.         0.15789694]\\n\",\n      \"10.703611767435968 [1.         0.38580316] [1.         0.33148521] [1.         0.26327223]\\n\",\n      \"11.057495178556948 [1.         0.40839311] [1.         0.36271256] [1.         0.28938717]\\n\",\n      \"11.257050659453116 [2.81169895 0.39462628] [2.17284109 0.41266398] [1.67387216 0.40595294]\\n\",\n      \"5.377969402147548 [2.42672982 0.33523331] [1.88879513 0.33183145] [1.38520387 0.33660904]\\n\",\n      \"10.583254643386892 [1.25119472 0.24808459] [1.22892184 0.19036331] [1.12883598 0.14981773]\\n\",\n      \"16.983683328439167 [1.        
0.3419402] [1.         0.30393983] [1.         0.25605751]\\n\",\n      \"5.050869203557615 [1.29198231 0.15229284] [1.20352486 0.1392287 ] [1.24414849 0.04265336]\\n\",\n      \"7.548112735771338 [1.         0.26548097] [1.         0.22705054] [1.         0.16898619]\\n\",\n      \"5.9141866300442905 [1.22341841 0.11054603] [1.13235639 0.10897532] [1.27037594 0.01602645]\\n\",\n      \"6.435282354321136 [1.13161061 0.26900905] [1.11807883 0.25184886] [1.08303097 0.22892779]\\n\",\n      \"13.609323509609263 [1.         0.32877443] [1.         0.30341652] [1.         0.20613864]\\n\",\n      \"15.184223641531718 [1.         0.40143023] [1.         0.36939884] [1.         0.31510046]\\n\",\n      \"13.913675186771753 [1.        0.3426041] [1.         0.29922944] [1.        0.2458466]\\n\",\n      \"8.477446212219109 [1.61136307 0.37478658] [1.48424498 0.34918918] [1.59215381 0.25505025]\\n\",\n      \"9.48077344085526 [1.79955938 0.33832709] [1.63219262 0.31941699] [1.67976814 0.24321916]\\n\",\n      \"12.133686880611858 [1.         0.33516168] [1.         0.32318787] [1.02600621 0.20780843]\\n\",\n      \"6.431276338217847 [1.14769126 0.26820245] [1.13731583 0.2497082 ] [1.10222323 0.2258782 ]\\n\",\n      \"14.602968511288568 [1.         0.26648705] [1.         0.22446566] [1.         0.17597921]\\n\",\n      \"15.924964263030803 [1.         0.35419677] [1.         0.30666105] [1.         0.24987524]\\n\",\n      \"10.073897762654642 [1.         0.48581412] [1.         0.43717362] [1.         0.37610968]\\n\",\n      \"12.504329670174597 [1.         0.41409509] [1.         0.38374635] [1.         0.33433319]\\n\",\n      \"6.148500583702118 [1.23945051 0.09895355] [1.13094388 0.10288653] [1.18233789 0.02107214]\\n\",\n      \"7.220512159715028 [1.         0.25923894] [1.         0.22030217] [1.         0.16259607]\\n\",\n      \"5.519340769919593 [1.28009163 0.16776537] [1.2402446  0.13756742] [1.2321024  0.05012039]\\n\",\n      \"13.64398389396302 [1.         
0.36034972] [1.         0.31352092] [1.         0.26309365]\\n\",\n      \"10.634962837733285 [1.35753494 0.22295601] [1.29589863 0.17297748] [1.18312851 0.13153996]\\n\",\n      \"9.454209269509137 [1.         0.44034392] [1.         0.38051256] [1.         0.29502614]\\n\",\n      \"20.443099022794467 [1.  0.5] [1.  0.5] [1.         0.47353101]\\n\",\n      \"5.223012826020312 [1.16866365 0.1297681 ] [1.155908   0.10422172] [1.19055712 0.02997721]\\n\",\n      \"12.776858241921095 [1.         0.36682438] [1.         0.31388614] [1.         0.24573757]\\n\",\n      \"3.3725971074362833 [1.         0.23624836] [1.         0.19662779] [1.         0.15500436]\\n\",\n      \"8.96699163099819 [1.4497885 0.2409892] [1.30910507 0.20188193] [1.16561168 0.16120304]\\n\",\n      \"8.932694276036234 [1.         0.36633319] [1.        0.3359799] [1.       0.288562]\\n\",\n      \"16.184842224718512 [1.         0.34924031] [1.         0.32232702] [1.39812381 0.14642405]\\n\",\n      \"5.659516528473731 [1.0277043  0.31353489] [1.         0.30011375] [1.         0.26279099]\\n\",\n      \"8.572522665084962 [1.29433658 0.34716409] [1.12911407 0.32252244] [1.         0.27356586]\\n\",\n      \"5.353650096895277 [1.         0.24643639] [1.         0.21292349] [1.         0.15890878]\\n\",\n      \"5.214421749586937 [1.03779245 0.30160609] [1.00041772 0.28857636] [1.         0.25020107]\\n\",\n      \"6.325109529375439 [1.08280791 0.28250599] [1.10108639 0.2571972 ] [1.04739482 0.24093943]\\n\",\n      \"6.8436534333722925 [1.         0.38925963] [1.         0.35722972] [1.         0.30763456]\\n\",\n      \"6.723067652469975 [1.10803599 0.07018331] [1.07515174 0.06293138] [1.04363398 0.05082909]\\n\",\n      \"11.709164432486157 [1.         0.43642546] [1.         0.40740024] [1.         0.36830802]\\n\",\n      \"5.3570439833478485 [1.12087528 0.36083822] [1.      0.34111] [1.         0.28348617]\\n\",\n      \"10.311399780442231 [1.         0.43636906] [1.         
0.37749159] [1.         0.29517295]\\n\",\n      \"9.472151547175129 [1.01455281 0.48212093] [1.         0.44170811] [1.         0.37051675]\\n\",\n      \"5.8702431454884945 [1.32963186 0.11462892] [1.24828534 0.10864314] [1.13822955 0.09836079]\\n\",\n      \"6.131922164901829 [1.         0.15196716] [1.         0.12767654] [1.         0.10427294]\\n\",\n      \"12.488894439856965 [1.         0.37644509] [1.         0.35530003] [1.         0.31405518]\\n\",\n      \"3.4997343302104658 [1.30354141 0.04477158] [1.27031851 0.04788207] [1.24386786 0.03928125]\\n\",\n      \"4.710214631785459 [1.15355498 0.12136588] [1.17552416 0.08284649] [1.15588257 0.04804761]\\n\",\n      \"6.3508030133773605 [1.         0.36685683] [1.         0.33957127] [1.         0.29943642]\\n\",\n      \"5.643347599813622 [1.47045292 0.0866719 ] [1.44325934 0.0513974 ] [1.35141321 0.04022204]\\n\",\n      \"5.577196759912198 [2.11268657 0.30958052] [1.92493724 0.27673218] [1.66359874 0.20946944]\\n\",\n      \"10.15041960848512 [1.         0.45245338] [1.        0.4455671] [1.         0.42473923]\\n\",\n      \"6.547922807246284 [1.12589041 0.08070766] [1.05853416 0.0737925 ] [1.0288873  0.06116788]\\n\",\n      \"4.820737038419537 [1.         0.32397085] [1.08401702 0.23804361] [1.15143934 0.11177209]\\n\",\n      \"12.566550715047965 [1.79690964 0.25083387] [1.69366263 0.20882101] [1.53867345 0.13593018]\\n\",\n      \"13.701983978967691 [1.         0.42134975] [1.         0.39213028] [1.         0.35769849]\\n\",\n      \"5.04770263943654 [1.         0.13258621] [1.         0.11076661] [1.         0.08275563]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"5.581341597426848 [1.        0.2900611] [1.         0.24783373] [1.         
0.18492683]\\n\",\n      \"5.311604781287308 [1.08782496 0.09834222] [1.06928917 0.0586435 ] [1.10309036 0.04706468]\\n\",\n      \"11.947792958577935 [1.18489773 0.22567682] [1.17241542 0.21127795] [1.14331286 0.18418474]\\n\",\n      \"7.245979577607476 [1.         0.27493374] [1.         0.23770554] [1.         0.18181668]\\n\",\n      \"9.495213434127194 [1.44451906 0.16169426] [1.39969108 0.15862781] [1.33265799 0.13646104]\\n\",\n      \"11.979729601337414 [1.         0.41507488] [1.         0.38521274] [1.         0.34710955]\\n\",\n      \"5.2193883889384045 [1.         0.39587723] [1.         0.37685163] [1.         0.33322685]\\n\",\n      \"5.744075709511627 [1.02852747 0.23553349] [1.02116077 0.21611541] [1.00459175 0.19130408]\\n\",\n      \"5.23308406819318 [1.4429346  0.39336271] [1.20736696 0.37503816] [1.07921141 0.32431613]\\n\",\n      \"4.722516315006456 [1.3978819  0.08579017] [1.37456425 0.04520417] [1.2666831  0.04766728]\\n\",\n      \"13.911817208795442 [1.         0.39398167] [1.         0.36708687] [1.         0.32054315]\\n\",\n      \"6.461381801065701 [1.         0.16969502] [1.         0.14792357] [1.         0.12572392]\\n\",\n      \"5.960267461947247 [1.26982858 0.05557321] [1.24874686 0.05688101] [1.22589303 0.05044706]\\n\",\n      \"10.25300108948741 [1.         0.46187059] [1.         0.41989989] [1.         0.35074268]\\n\",\n      \"11.787179414869579 [1.76632527 0.41869485] [1.43269394 0.4198818 ] [1.24296689 0.3933356 ]\\n\",\n      \"5.958413956790485 [1.30835065 0.11653857] [1.22474741 0.11378343] [1.16048522 0.09304549]\\n\",\n      \"10.221917231289439 [1.         0.42500705] [1.         0.38027495] [1.         0.30968143]\\n\",\n      \"10.578450102041375 [1.25961277 0.25261408] [1.22239005 0.2208682 ] [1.19269214 0.14752483]\\n\",\n      \"11.968923765333345 [1.         0.42743135] [1.         0.40085837] [1.         
0.36184916]\\n\",\n      \"6.499917748925294 [1.09647939 0.06982063] [1.05527086 0.06521202] [1.02322735 0.05498115]\\n\",\n      \"6.085882935694625 [1.         0.39960601] [1.         0.36280198] [1.        0.3084111]\\n\",\n      \"5.173288811059643 [1.11839091 0.36623351] [1.00609892 0.35280735] [1.         0.29913639]\\n\",\n      \"5.5329798066873535 [1.         0.28536787] [1.         0.26654426] [1.         0.23642846]\\n\",\n      \"5.4451650886986 [1.         0.22766355] [1.        0.1988036] [1.         0.14749155]\\n\",\n      \"14.958895497278894 [1.         0.36682107] [1.       0.338425] [1.43888911 0.14979942]\\n\",\n      \"8.940743720672465 [1.20547298 0.35744166] [1.07560059 0.33012045] [1.         0.27251407]\\n\",\n      \"7.436545361874667 [1.        0.3183971] [1.         0.29659595] [1.        0.2602954]\\n\",\n      \"9.134181236204086 [1.20762625 0.28111081] [1.14802603 0.2339364 ] [1.05651465 0.18709927]\\n\",\n      \"5.731991104797348 [1.41909933 0.44318859] [1.21217277 0.42579532] [1.06809325 0.39116092]\\n\",\n      \"5.206514265179677 [1.27050551 0.06583924] [1.25125421 0.06569998] [1.23254293 0.05498154]\\n\",\n      \"9.104566721644789 [1.         0.23076818] [1.         0.20671271] [1.        0.1702677]\\n\",\n      \"9.503861406452184 [1.        0.3933715] [1.         0.33789332] [1.        0.2698494]\\n\",\n      \"13.423185718655846 [1.         0.41264971] [1.         0.38841542] [1.         0.35286958]\\n\",\n      \"5.273742240055194 [1.31152661 0.08900905] [1.30530183 0.06828644] [1.29780875 0.02460142]\\n\",\n      \"9.38267066376904 [1.28749799 0.24714228] [1.2401462  0.20082994] [1.13681382 0.1539198 ]\\n\",\n      \"5.554657782269932 [1.00974729 0.1073955 ] [1.         0.09572028] [1.         
0.07008646]\\n\",\n      \"5.3551988638892825 [1.03802083 0.23392611] [1.04737131 0.19578504] [1.08963842 0.11075911]\\n\",\n      \"5.078271939224496 [1.31380106 0.15165493] [1.22164349 0.14006279] [1.25888297 0.04407429]\\n\",\n      \"12.209527653607935 [1.12610795 0.24827783] [1.09141976 0.22121143] [1.05152291 0.17796183]\\n\",\n      \"7.236205279556641 [1.         0.26082362] [1.        0.2221064] [1.         0.16596905]\\n\",\n      \"6.329694670513371 [1.03281315 0.15687788] [1.05061858 0.13566962] [1.28249508 0.01469789]\\n\",\n      \"3.4062499660108245 [1.05565573 0.29287467] [1.4782522  0.01282449] [1.27482247 0.0235219 ]\\n\",\n      \"13.728436212527166 [1.         0.34172437] [1.        0.3249018] [1.0515369  0.20364606]\\n\",\n      \"7.910498525300672 [1.31618339 0.39247978] [1.09608015 0.38332792] [1.         0.34583144]\\n\",\n      \"12.309733990060018 [1.02936961 0.23070053] [1.05595266 0.17433231] [1.06823147 0.12203047]\\n\",\n      \"11.682042964572211 [1.         0.34170111] [1.         0.32074289] [1.35200494 0.14586391]\\n\",\n      \"9.349333543933142 [1.12337531 0.45501582] [1.02980351 0.42866449] [1.         0.37751805]\\n\",\n      \"12.479948608243525 [1.         0.33288736] [1.         0.30745312] [1.15279318 0.17813779]\\n\",\n      \"12.218947880890472 [1.         0.26605717] [1.         0.22768542] [1.         0.18603149]\\n\",\n      \"3.2464472905915756 [1.11891244 0.25752831] [1.44933499 0.00799026] [1.25140396 0.02325195]\\n\",\n      \"13.56354539998666 [1.         0.34602254] [1.         0.32759299] [1.         0.22130759]\\n\",\n      \"7.444098522183007 [1.67172634 0.37452495] [1.53363559 0.34770609] [1.50353595 0.28038547]\\n\",\n      \"6.278158304990354 [1.         0.17046653] [1.         0.15237823] [1.18718045 0.03143365]\\n\",\n      \"6.849653895905821 [1.         0.26536084] [1.        0.2303866] [1.         0.17477424]\\n\",\n      \"14.1965028140593 [1.         0.34648494] [1.        0.3021621] [1.         
0.25534017]\\n\",\n      \"6.92101570965483 [1.         0.24208607] [1.        0.2148202] [1.        0.1584444]\\n\",\n      \"5.601546672059479 [1.         0.10607985] [1.         0.09256911] [1.         0.06920562]\\n\",\n      \"4.801359607137906 [1.30545915 0.15938466] [1.25225536 0.13460016] [1.26147593 0.04548683]\\n\",\n      \"11.62181801747974 [1.27719147 0.25835791] [1.26022512 0.20337497] [1.11133752 0.16349657]\\n\",\n      \"9.689429637875344 [1.03763838 0.3886762 ] [1.         0.34108572] [1.        0.2735025]\\n\",\n      \"5.014053782790492 [1.27711913 0.06474456] [1.26184626 0.06451575] [1.23906064 0.05258765]\\n\",\n      \"13.132059954438168 [1.         0.43430573] [1.         0.40595153] [1.         0.36851516]\\n\",\n      \"4.851125703567041 [1.84975898 0.3715327 ] [1.49529655 0.36010651] [1.22477509 0.33813967]\\n\",\n      \"9.956250610421554 [1.32748806 0.26442972] [1.20954762 0.22638284] [1.05842013 0.18972914]\\n\",\n      \"9.843111602971783 [1.21746058 0.36294867] [1.05752727 0.33936087] [1.         0.27209566]\\n\",\n      \"5.625140801016169 [1.18136395 0.254585  ] [1.1235811  0.23946801] [1.04228104 0.22170254]\\n\",\n      \"7.339319468599995 [1.02401503 0.31604764] [1.05925218 0.28613123] [1.03930775 0.26056254]\\n\",\n      \"5.863058274090753 [1.         0.25325056] [1.         0.22913271] [1.         0.18304533]\\n\",\n      \"5.563912449697366 [1.15424635 0.35983367] [1.00025694 0.34335127] [1.         0.28503896]\\n\",\n      \"10.411823162742717 [1.         0.43464808] [1.         0.40459801] [1.        0.3645003]\\n\",\n      \"6.024901040524667 [1.10853677 0.06224004] [1.06573987 0.05866889] [1.03424444 0.04801264]\\n\",\n      \"5.881619218655779 [1.         0.37912158] [1.         0.33593401] [1.         0.27946323]\\n\",\n      \"9.201288617140614 [1.13673983 0.28871202] [1.12352696 0.25245452] [1.07993825 0.19068865]\\n\",\n      \"10.799481336264586 [1.         0.41547583] [1.         0.36751483] [1.         
0.29410016]\\n\",\n      \"10.975704498754395 [2.13255797 0.35624651] [1.75726357 0.3653143 ] [1.49769767 0.35076586]\\n\",\n      \"11.108866793622619 [1.         0.46047068] [1.        0.4200483] [1.         0.35238113]\\n\",\n      \"5.362009213568577 [1.28449891 0.05576823] [1.26617913 0.05598244] [1.23530509 0.05178904]\\n\",\n      \"13.60001791724086 [1.         0.39290307] [1.         0.36491517] [1.         0.31803772]\\n\",\n      \"5.824120077981131 [1.        0.1691489] [1.         0.14150499] [1.         0.11773266]\\n\",\n      \"5.227446595838796 [1.41883214 0.08755289] [1.39051572 0.04165076] [1.25145427 0.0517093 ]\\n\",\n      \"5.61139609746584 [1.4652467  0.38710518] [1.22706474 0.36978388] [1.09243372 0.31964533]\\n\",\n      \"4.171756045879388 [1.29394775 0.35944299] [1.07465225 0.36804221] [1.         0.34250625]\\n\",\n      \"4.395064708870857 [1.11969056 0.14840197] [1.13031667 0.1135897 ] [1.12875825 0.06807764]\\n\",\n      \"7.48165796112809 [1.20592439 0.28317858] [1.20088031 0.2446455 ] [1.14410797 0.17525922]\\n\",\n      \"7.209986956838854 [1.023589   0.12307983] [1.01336926 0.09886015] [1.00226365 0.07933181]\\n\",\n      \"10.678851050677565 [1.23783537 0.20936876] [1.22534    0.19914427] [1.20163314 0.17139288]\\n\",\n      \"12.372406636226975 [1.        0.4488869] [1.         0.42012377] [1.         0.38271775]\\n\",\n      \"15.577807191071372 [1.02523902 0.45866847] [1.08312272 0.37623742] [1.20766226 0.23768777]\\n\",\n      \"8.734471938153105 [1.36523752 0.17271522] [1.3134492  0.16957917] [1.25456613 0.14811045]\\n\",\n      \"7.610336442791224 [1.         0.26843504] [1.         0.23356436] [1.         0.17790642]\\n\",\n      \"8.863838171017678 [1.41825983 0.31832998] [1.26990717 0.29132318] [1.19563001 0.22590613]\\n\",\n      \"7.745268162820891 [1.13292471 0.26612377] [1.         0.30604906] [1.         0.21973743]\\n\",\n      \"4.589678527103571 [1.       0.181678] [1.         0.15089917] [1.         
0.12282483]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"4.093994707301205 [1.08729292 0.08806179] [1.06692303 0.06636428] [1.03433912 0.05085375]\\n\",\n      \"4.803442691616833 [1.         0.49318999] [1.         0.45994491] [1.         0.41414422]\\n\",\n      \"8.141877447013025 [1.       0.378827] [1.         0.35032184] [1.         0.31377014]\\n\",\n      \"5.469339275345639 [1.05194881 0.49235425] [1.        0.4531681] [1.         0.36628546]\\n\",\n      \"10.552513200233147 [1.         0.46403819] [1.02642136 0.4204548 ] [1.00651206 0.36114381]\\n\",\n      \"7.912086669941295 [1.4773809  0.02540787] [1.         0.24006762] [1.         0.22141444]\\n\",\n      \"5.1324676655653985 [1.04788493 0.39280235] [1.         0.36007542] [1.         0.29963288]\\n\",\n      \"8.464687671158348 [1.         0.41448182] [1.         0.37280512] [1.         0.30366964]\\n\",\n      \"8.831959464848165 [1.01165928 0.38713134] [1.19958591 0.27802433] [1.         0.27581693]\\n\",\n      \"3.689023065396748 [1.23016991 0.19426718] [1.        0.2285476] [1.         0.19498156]\\n\",\n      \"7.739213941895897 [1.14601354 0.24850415] [1.13882119 0.21963422] [1.0931114 0.1761189]\\n\",\n      \"4.760667671639348 [1.56818228 0.09992981] [1.51742308 0.07806965] [1.19779275 0.06806795]\\n\",\n      \"7.158135476697359 [1.08911821 0.13319368] [1.03077099 0.12626594] [1.10715564 0.0431315 ]\\n\",\n      \"12.517284088482697 [1.37487033 0.32657322] [1.22994953 0.32196908] [1.76583206 0.12341415]\\n\",\n      \"23.968827345650084 [1.         0.41066761] [1.         0.37635411] [1.11578731 0.2836719 ]\\n\",\n      \"9.643650548801416 [1.3668108  0.17932402] [1.14939936 0.20033119] [1.16055901 0.09596568]\\n\",\n      \"3.0839077781933777 [1.         0.22836315] [1.         0.21865558] [1.       0.192191]\\n\",\n      \"9.251162412565186 [1.         0.35180185] [1.06255665 0.2844368 ] [1.        
0.2615466]\\n\",\n      \"3.6316623787021496 [1.02245801 0.09782654] [1.00703226 0.0823112 ] [1.00851461 0.05555117]\\n\",\n      \"8.711853634715693 [1.2598993 0.2082685] [1.2496592  0.16908508] [1.18118223 0.11168647]\\n\",\n      \"5.71651001075999 [1.36176414 0.352993  ] [1.15731905 0.34234659] [1.12412015 0.27307774]\\n\",\n      \"8.29629418921876 [1.68012885 0.38619025] [1.61051734 0.33103658] [1.35520496 0.30506579]\\n\",\n      \"11.504291611676422 [1.         0.39222487] [1.         0.35314538] [1.         0.29158934]\\n\",\n      \"6.6256613573123095 [1.         0.42684849] [1.         0.38465708] [1.         0.33054167]\\n\",\n      \"4.404122811378351 [1.12461283 0.18642639] [1.16004867 0.15665071] [1.13695563 0.126951  ]\\n\",\n      \"4.1977340356457615 [1.14924923 0.1855828 ] [1.16115423 0.1522952 ] [1.12954357 0.11237641]\\n\",\n      \"13.427749437316683 [1.21345884 0.29035845] [1.11641823 0.26892753] [1.01931096 0.224942  ]\\n\",\n      \"6.221542448393431 [1.47309959 0.07790591] [1.38973338 0.07566895] [1.27297753 0.06992654]\\n\",\n      \"5.129024047187179 [1.53184879 0.39148192] [1.26484378 0.37476782] [1.         0.34922252]\\n\",\n      \"18.89243184198175 [1.26707603 0.31528262] [1.11724298 0.31020941] [1.17470508 0.18765881]\\n\",\n      \"3.8403752208504867 [1.         0.10222138] [1.         0.08408468] [1.00121213 0.05810218]\\n\",\n      \"6.680468368549407 [1.07202455 0.26639423] [1.00132482 0.2604441 ] [1.         0.22536632]\\n\",\n      \"18.141218087495645 [1.10184668 0.17468229] [1.         0.19923966] [1.19758084 0.05404643]\\n\",\n      \"9.657772216906958 [1.35405939 0.39086437] [1.31007881 0.33573011] [1.24772003 0.28072832]\\n\",\n      \"11.316065831914312 [1.        0.3488616] [1.         0.30311498] [1.         
0.22630642]\\n\",\n      \"7.070632914829092 [1.18157271 0.1220394 ] [1.16697923 0.11599484] [1.12292463 0.11577101]\\n\",\n      \"18.838200457610366 [1.32470966 0.16083723] [1.0631238  0.21220793] [1.34402963 0.04964674]\\n\",\n      \"12.797296176162535 [1.        0.3537742] [1.         0.28730011] [1.         0.20137252]\\n\",\n      \"11.898751500782648 [1.0697212  0.21654601] [1.06388615 0.18858074] [1.03748862 0.16535724]\\n\",\n      \"6.9805811813918766 [1.09589326 0.22441316] [1.0947465  0.21055229] [1.08107875 0.1782727 ]\\n\",\n      \"5.784120861661739 [1.12294249 0.17919044] [1.1134077  0.14471304] [1.08794961 0.10354823]\\n\",\n      \"7.325677825701441 [1.10865634 0.18510464] [1.09070859 0.15150712] [1.06172407 0.1120904 ]\\n\",\n      \"12.02613260534443 [1.06568608 0.21866432] [1.06608933 0.18844478] [1.03981067 0.16484466]\\n\",\n      \"7.352216317043304 [1.15006387 0.21032094] [1.14996928 0.19553264] [1.1314296  0.16176135]\\n\",\n      \"5.320183537847537 [1.18137537 0.17844575] [1.36114489 0.09025131] [1.29681948 0.09351459]\\n\",\n      \"12.26260767095218 [1.         0.36513769] [1.12747938 0.26495171] [1.19251323 0.1530195 ]\\n\",\n      \"18.352927282901998 [1.45504738 0.14163784] [1.13158679 0.21087611] [1.55840008 0.02237506]\\n\",\n      \"8.285321285112705 [1.2798559  0.40871858] [1.28288136 0.33896019] [1.21928331 0.28481416]\\n\",\n      \"18.779324428542555 [1.25938515 0.15522197] [1.02509404 0.21349648] [1.40607741 0.03955527]\\n\",\n      \"4.799003114546447 [2.10367131 0.30985834] [1.83451725 0.2763058 ] [1.37548665 0.24560329]\\n\",\n      \"6.225630771591334 [1.65794811 0.2520898 ] [1.28025577 0.2659821 ] [1.37440044 0.11968063]\\n\",\n      \"4.700024647442787 [1.18357376 0.2345532 ] [1.06782747 0.23658913] [1.         0.22019833]\\n\",\n      \"4.89408498110031 [1.02430687 0.10613572] [1.        0.0956662] [1.08135166 0.03583199]\\n\",\n      \"17.788725847584235 [1.         0.37425519] [1.         0.31710407] [1.        
0.2181534]\\n\",\n      \"10.424158600402126 [1.08436329 0.46019149] [1.06997713 0.42480411] [1.04629749 0.36934639]\\n\",\n      \"8.856097946871373 [1.45906091 0.09125215] [1.40401634 0.0859803 ] [1.31688577 0.07523226]\\n\",\n      \"10.792286172348518 [1.         0.39910361] [1.         0.37032642] [1.         0.32078214]\\n\",\n      \"13.584834426854492 [1.25335423 0.276269  ] [1.15108925 0.25865294] [1.02356115 0.21844811]\\n\",\n      \"15.965033650655808 [1.         0.38252867] [1.         0.33613805] [1.         0.26891776]\\n\",\n      \"7.952627301951928 [1.43681818 0.40661874] [1.36344022 0.35890101] [1.23925417 0.32109843]\\n\",\n      \"3.9407820659340667 [1.         0.20590304] [1.         0.19076009] [1.         0.16279285]\\n\",\n      \"4.912038362884037 [1.20585882 0.37346754] [1.11804821 0.33153227] [1.14481908 0.23723253]\\n\",\n      \"3.1842083872154476 [1.0278033 0.0922055] [1.01598649 0.07905709] [1.00054893 0.0633468 ]\\n\",\n      \"8.471255354662132 [1.02855462 0.26801312] [1.01591953 0.23587559] [1.         0.17335107]\\n\",\n      \"9.32718279753456 [1.         0.34868665] [1.05331739 0.28080309] [1.         0.25746615]\\n\",\n      \"2.2439102538703395 [1.         0.22132079] [1.         0.21282756] [1.         0.18606632]\\n\",\n      \"9.114192843725885 [1.34652548 0.17587747] [1.15092456 0.19459356] [1.16823653 0.09260065]\\n\",\n      \"22.017509804779 [1.         0.42813566] [1.         0.39476122] [1.         0.33426201]\\n\",\n      \"13.275015059811253 [1.57486061 0.31525481] [1.38372599 0.30776579] [1.84399762 0.12205398]\\n\",\n      \"4.02098845703685 [1.69317948 0.064384  ] [1.69732246 0.03030978] [1.        0.0742018]\\n\",\n      \"6.6819724736065815 [1.11272013 0.12540132] [1.03700053 0.12594568] [1.11223167 0.0400628 ]\\n\",\n      \"6.847572869994068 [1.19752718 0.23147558] [1.17284419 0.20639019] [1.06399262 0.1794046 ]\\n\",\n      \"6.398408999185493 [1.         0.27380186] [1.         0.24343145] [1.         
0.20607567]\\n\",\n      \"10.624225941524706 [1.         0.39706335] [1.19270675 0.29190137] [1.19549722 0.21258143]\\n\",\n      \"8.302253204235976 [1.11878213 0.35309838] [1.         0.33641828] [1.        0.2801558]\\n\",\n      \"8.646551408567877 [1.50749827 0.02628092] [1.         0.24240765] [1.         0.22220762]\\n\",\n      \"6.861357749325711 [1.05921528 0.37447684] [1.        0.3469386] [1.         0.28926633]\\n\",\n      \"12.230604804267584 [1.         0.41304914] [1.         0.36614359] [1.         0.29259791]\\n\",\n      \"8.570476761110646 [1.07287849 0.28416565] [1.03545035 0.25449404] [1.01177274 0.19022635]\\n\",\n      \"6.402176973208937 [1.27870191 0.42758859] [1.14803193 0.405654  ] [1.         0.35484543]\\n\",\n      \"4.139203158019425 [1.18878555 0.05133137] [1.08794641 0.05691346] [1.05518293 0.04320333]\\n\",\n      \"4.540621081127601 [1.07535234 0.48389816] [1.         0.45985159] [1.        0.4135118]\\n\",\n      \"13.724873628346414 [1.         0.38667998] [1.         0.34295486] [1.         0.28298764]\\n\",\n      \"6.502438182397647 [1.2524666  0.18903876] [1.         0.23122934] [1.         0.16804861]\\n\",\n      \"9.340013114209574 [1.24206983 0.34294705] [1.15404461 0.31099692] [1.13876964 0.2399859 ]\\n\",\n      \"4.818664433238953 [1.42342594 0.09841502] [1.40093916 0.06054477] [1.32942883 0.04353598]\\n\",\n      \"9.597347521634239 [1.27180802 0.34215124] [1.17644176 0.3101468 ] [1.15953992 0.23879405]\\n\",\n      \"6.531684024239703 [1.17225648 0.45090261] [1.04895417 0.42663058] [1.        0.3545852]\\n\",\n      \"7.902970382306915 [1.        0.3828321] [1.        0.3509311] [1.         
0.31221578]\\n\",\n      \"3.572856543109783 [1.52084744 0.40874345] [1.42490507 0.38298202] [1.17717186 0.36366166]\\n\",\n      \"3.7748913600772416 [1.05285317 0.10340779] [1.06472882 0.06956567] [1.03162317 0.05278179]\\n\",\n      \"5.892350057077822 [1.15257191 0.22656358] [1.11848242 0.19031226] [1.12692081 0.11331438]\\n\",\n      \"7.656406946637619 [1.         0.38780176] [1.         0.33601521] [1.         0.26281013]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"6.0598566365207 [1.47693262 0.29409661] [1.28154119 0.2810116 ] [1.17061817 0.24014481]\\n\",\n      \"6.779243867416825 [1.45616599 0.02661183] [1.1304123  0.16180863] [1.         0.22385116]\\n\",\n      \"14.01559868222701 [1.         0.47228961] [1.         0.43049894] [1.         0.37735332]\\n\",\n      \"8.385329599986996 [1.17041771 0.24841201] [1.14274271 0.22358384] [1.09631354 0.17970203]\\n\",\n      \"16.0227451383004 [1.         0.34002811] [1.         0.32543527] [1.50375927 0.12814954]\\n\",\n      \"6.535906025140597 [1.08511613 0.13623745] [1.0503657  0.12126746] [1.12643462 0.03720826]\\n\",\n      \"5.605220318668054 [1.48784368 0.09234066] [1.40130347 0.08766213] [1.27551035 0.06062238]\\n\",\n      \"9.740896512046973 [1.43057947 0.16835786] [1.18733111 0.19233386] [1.18582144 0.09265421]\\n\",\n      \"24.62920484905838 [1.19588232 0.39341377] [1.30426752 0.33270039] [1.31818841 0.24892338]\\n\",\n      \"9.87361167273439 [1.         0.35505983] [1.03153041 0.28011414] [1.         0.25790464]\\n\",\n      \"10.785432130473906 [1.0100719  0.22856359] [1.         0.20751649] [1.         0.17314927]\\n\",\n      \"2.848822870218504 [1.         0.23324801] [1.         0.22447959] [1.         0.19795285]\\n\",\n      \"7.415554338010629 [1.         0.28088956] [1.         0.24938439] [1.         0.18019241]\\n\",\n      \"3.172457896671065 [1.         0.10193601] [1.         0.07618671] [1.         
0.05141725]\\n\",\n      \"9.133480894985293 [1.         0.31089858] [1.         0.27625971] [1.         0.21400311]\\n\",\n      \"5.104756158934721 [1.1664368  0.16518938] [1.12257252 0.15407822] [1.05594699 0.14106247]\\n\",\n      \"10.641007089861258 [1.19738839 0.39723323] [1.07591154 0.37229626] [1.         0.31978612]\\n\",\n      \"8.219070864290465 [1.         0.46583664] [1.         0.41833496] [1.         0.34931202]\\n\",\n      \"13.763388211606413 [2.20533378 0.29447472] [2.02780724 0.26782876] [1.84372838 0.21020024]\\n\",\n      \"4.755651120898854 [2.0131304  0.28377263] [1.74854002 0.25357947] [1.49874913 0.19254194]\\n\",\n      \"10.599803234656594 [1.        0.1375559] [1.         0.12323069] [1.066659   0.05801778]\\n\",\n      \"13.2840212012749 [1.25982194 0.2601624 ] [1.1527094  0.24049561] [1.04625481 0.20054501]\\n\",\n      \"16.480292791150053 [1.38280319 0.30029424] [1.18113575 0.30240947] [1.22205371 0.17240414]\\n\",\n      \"11.512793233721805 [1.        0.4172398] [1.         0.37968987] [1.         0.31791679]\\n\",\n      \"6.296638729448851 [1.20201688 0.19721292] [1.17676438 0.17814056] [1.13224075 0.14145692]\\n\",\n      \"4.84576752198711 [1.         0.11971112] [1.         0.09699513] [1.03263381 0.0528492 ]\\n\",\n      \"7.508196011967394 [1.         0.28483445] [1.        0.2603375] [1.         0.22496615]\\n\",\n      \"7.298552585311891 [1.31441801 0.15227896] [1.33130135 0.09860831] [1.24900776 0.08241866]\\n\",\n      \"5.974410989334576 [1.         0.32068934] [1.         0.30168836] [1.         
0.26689476]\\n\",\n      \"4.5573463750190335 [1.8507893  0.33166136] [1.62845947 0.29869509] [1.31513428 0.25783196]\\n\",\n      \"8.650592174777097 [1.03185706 0.22336286] [1.01987652 0.18605172] [1.0036008  0.14088206]\\n\",\n      \"8.041229519353713 [1.1450843  0.21811508] [1.12986112 0.20591502] [1.10609819 0.17487872]\\n\",\n      \"9.202195094288346 [1.19192202 0.21032271] [1.1642241  0.19891529] [1.1283512  0.16956792]\\n\",\n      \"8.474940572227336 [1.02833622 0.2146775 ] [1.01500988 0.18189212] [1.01208312 0.13188736]\\n\",\n      \"4.679237754906878 [1.94209567 0.31654485] [1.72724612 0.27689186] [1.35220909 0.23855317]\\n\",\n      \"8.844457686474382 [1.32179573 0.17843325] [1.30613842 0.14585346] [1.22776403 0.13407708]\\n\",\n      \"12.67436564252188 [1.         0.37325004] [1.09665554 0.28260383] [1.21694107 0.14876266]\\n\",\n      \"19.001589176677932 [1.37484114 0.15854447] [1.10288975 0.21824572] [1.56019944 0.02856289]\\n\",\n      \"6.277685898917732 [1.         0.32013614] [1.         0.30090605] [1.         0.26586097]\\n\",\n      \"7.4658412461713715 [1.15103985 0.16408361] [1.2410312  0.11025922] [1.18201128 0.11624398]\\n\",\n      \"9.467478293219816 [1.        0.2930279] [1.        0.2693233] [1.         0.23600485]\\n\",\n      \"5.421451808832461 [1.        0.1161232] [1.         0.09394996] [1.04443591 0.04547173]\\n\",\n      \"9.910064275264103 [1.         0.42451735] [1.         0.38860386] [1.         0.32880158]\\n\",\n      \"6.639947118006169 [1.12744732 0.21638269] [1.15605827 0.17751135] [1.14616089 0.12837949]\\n\",\n      \"14.319552015517354 [1.27684374 0.30564773] [1.17582248 0.29122776] [1.24391082 0.15795376]\\n\",\n      \"13.057527436425602 [1.22696498 0.26925168] [1.1368579  0.25169427] [1.02884318 0.21011612]\\n\",\n      \"20.160947220029147 [1.         0.44599927] [1.15394538 0.37883883] [1.42786786 0.25589613]\\n\",\n      \"10.326202242694293 [1.         0.13608993] [1.         
0.12089174] [1.05266966 0.06732383]\\n\",\n      \"5.121533611485272 [1.19741657 0.14656392] [1.13724815 0.14571089] [1.0618522  0.13704809]\\n\",\n      \"8.675159652180747 [1.         0.46285123] [1.         0.40846725] [1.         0.33408502]\\n\",\n      \"13.511408671173704 [1.08802566 0.40736015] [1.05257108 0.37110429] [1.         0.31684522]\\n\",\n      \"10.480884661942895 [1.        0.2838687] [1.         0.25676719] [1.         0.19769731]\\n\",\n      \"3.0063233366563167 [1.         0.10283038] [1.         0.07852243] [1.         0.05381072]\\n\",\n      \"4.8062820810801945 [1.         0.26798569] [1.01153599 0.22781085] [1.10480981 0.11761344]\\n\",\n      \"2.6561197373573475 [1.         0.23372729] [1.         0.22384593] [1.         0.19910501]\\n\",\n      \"9.31748771964591 [1.         0.36031071] [1.04516442 0.27846026] [1.         0.26007056]\\n\",\n      \"7.0285335908427315 [1.09811464 0.13826218] [1.02092417 0.13659478] [1.11624688 0.04206582]\\n\",\n      \"5.881240986614894 [1.54909476 0.11237151] [1.48535367 0.08971277] [1.26212689 0.07374479]\\n\",\n      \"13.600209256516814 [1.42836667 0.28293441] [1.21331278 0.29111539] [1.80904706 0.09467976]\\n\",\n      \"7.818621978978676 [1.03597704 0.24015199] [1.         0.22871728] [1.         0.19159323]\\n\",\n      \"10.472367783813572 [1.02827877 0.30487784] [1.04517289 0.26376533] [1.06375552 0.19330964]\\n\",\n      \"8.330206486619348 [1.48694105 0.02501354] [1.02458533 0.22123065] [1.         0.22381229]\\n\",\n      \"14.011543965087405 [1.         0.46722971] [1.19874565 0.37258221] [1.15463464 0.32165357]\\n\",\n      \"10.097498540708527 [1.         0.38901564] [1.0837705  0.30853196] [1.08180111 0.23599267]\\n\",\n      \"7.660246741096868 [1.1920754 0.2060254] [1.14358882 0.17106914] [1.13787037 0.09124709]\\n\",\n      \"10.13680153648991 [1.         0.37900476] [1.         0.34928027] [1.         0.31010357]\\n\",\n      \"11.010190874450128 [1.         0.42048089] [1.        
 0.39229723] [1.         0.35350634]\\n\",\n      \"3.529463722512494 [1.06223271 0.09723751] [1.07215913 0.06952176] [1.04118846 0.04876519]\\n\",\n      \"4.223505986082864 [1.18576262 0.45114874] [1.08547425 0.42215501] [1.         0.36016624]\\n\",\n      \"4.030116636362381 [1.35428828 0.12524804] [1.33834494 0.08534257] [1.27148119 0.06374946]\\n\",\n      \"4.40204128999036 [1.38641626 0.14479881] [1.         0.22534081] [1.         0.14847461]\\n\",\n      \"8.129061521692863 [1.         0.31855541] [1.         0.32389108] [1.         0.23611306]\\n\",\n      \"9.078389556832173 [1.         0.39043993] [1.         0.34562551] [1.         0.28046481]\\n\",\n      \"10.922586828862263 [1.         0.44638453] [1.01467007 0.40765968] [1.01647234 0.35099867]\\n\",\n      \"7.446449166826484 [1.04170848 0.31158162] [1.         0.28492783] [1.         0.21417657]\\n\",\n      \"14.643796412881152 [1.         0.38914462] [1.        0.3448852] [1.         0.27999371]\\n\",\n      \"4.633343637535566 [1.13118718 0.07027916] [1.06857895 0.06254089] [1.0429486  0.04713405]\\n\",\n      \"3.5184877298587742 [1.31444618 0.42958184] [1.19518759 0.40567312] [1.05849899 0.37707793]\\n\",\n      \"7.104549282939431 [1.19273211 0.22548999] [1.15754557 0.2041024 ] [1.06418321 0.17410454]\\n\",\n      \"5.184466709805151 [1.         0.24197917] [1.         0.21527012] [1.         0.17588522]\\n\",\n      \"8.573719732087964 [1.         0.39229011] [1.24729093 0.27826192] [1.17925231 0.21276987]\\n\",\n      \"6.8874927234004275 [1.52488306 0.02261898] [1.17100034 0.1418223 ] [1.         0.25747024]\\n\",\n      \"5.394806335622542 [1.15188767 0.35389505] [1.0431721  0.33231426] [1.         
0.28469106]\\n\",\n      \"9.012881614996276 [1.32162831 0.19644876] [1.15275734 0.20352503] [1.16051101 0.10955896]\\n\",\n      \"13.107302558559942 [1.44262958 0.34513576] [1.31695146 0.31758483] [1.73725417 0.14804349]\\n\",\n      \"6.356884312158647 [1.43232794 0.33397626] [1.19806779 0.33130407] [1.05299997 0.31444563]\\n\",\n      \"4.361883497071163 [1.49528507 0.12212512] [1.43113383 0.10107   ] [1.30541974 0.06793306]\\n\",\n      \"5.908451486291657 [1.27620447 0.16820832] [1.21399752 0.15641256] [1.16399164 0.14727655]\\n\",\n      \"6.348470636794943 [1.06501106 0.14323209] [1.02946656 0.12921386] [1.09469555 0.05045214]\\n\",\n      \"4.08822226206924 [1.         0.11199233] [1.         0.09263176] [1.         0.06998212]\\n\",\n      \"9.446789209925925 [1.         0.29673502] [1.         0.26381312] [1.         0.20178498]\\n\",\n      \"9.125675380237348 [1.         0.29819512] [1.         0.25752205] [1.         0.17762492]\\n\",\n      \"10.097362922703098 [1.         0.35659346] [1.0360234  0.27651893] [1.         0.25969843]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"3.7685768629199634 [1.         0.25506584] [1.         0.24477779] [1.         0.22149523]\\n\",\n      \"12.170617812549592 [1.         0.43683909] [1.         0.40373471] [1.         0.34853848]\\n\",\n      \"12.144325023976737 [1.16927628 0.29559312] [1.09902816 0.270262  ] [1.02206552 0.22197826]\\n\",\n      \"7.8854542794127775 [1.07043808 0.4451352 ] [1.         0.41019962] [1.         0.34149892]\\n\",\n      \"5.747458873321492 [1.04981411 0.24164911] [1.07435791 0.20507094] [1.11552737 0.13974482]\\n\",\n      \"4.859785336888636 [1.32823536 0.35237545] [1.17587717 0.31881741] [1.13394924 0.24219468]\\n\",\n      \"6.1713753736067 [1.14537152 0.22041008] [1.12506424 0.20464841] [1.0743455  0.18970252]\\n\",\n      \"6.389784160204872 [1.00705619 0.12245689] [1.         
0.10373429] [1.05016515 0.05450275]\\n\",\n      \"4.843863878027568 [1.54593086 0.06172297] [1.4261355  0.06443288] [1.28244067 0.06144386]\\n\",\n      \"9.944508788523496 [1.         0.43965845] [1.         0.40195183] [1.         0.34397386]\\n\",\n      \"7.146629277531506 [1.08914226 0.18917886] [1.0781795  0.15343313] [1.04779075 0.11422093]\\n\",\n      \"6.251557015695687 [1.1220385  0.13225557] [1.07936011 0.14064204] [1.04056439 0.14059275]\\n\",\n      \"20.189889828617833 [1.21265497 0.20684389] [1.         0.24336379] [1.22663225 0.08532478]\\n\",\n      \"18.547647853136688 [1.         0.19281915] [1.         0.19805366] [1.         0.07785673]\\n\",\n      \"4.971371152995008 [1.72292261 0.3663594 ] [1.63078777 0.32088663] [1.19328626 0.30544493]\\n\",\n      \"21.805115878754577 [1.         0.24911644] [1.       0.240072] [1.07342328 0.09893232]\\n\",\n      \"4.84573124042842 [1.52474305 0.3949025 ] [1.33786432 0.36768522] [1.05297759 0.34824244]\\n\",\n      \"6.826724211021385 [1.21798918 0.11667315] [1.19300584 0.11399387] [1.156586   0.10732942]\\n\",\n      \"20.435224852598804 [1.19998484 0.21229644] [1.         0.24489777] [1.21058254 0.08707851]\\n\",\n      \"5.749119594419538 [1.         0.34264062] [1.         0.31544507] [1.         0.26280049]\\n\",\n      \"6.275567095482205 [1.05266183 0.22415031] [1.15728881 0.12948833] [1.11230191 0.09487598]\\n\",\n      \"6.247209539950038 [1.51867317 0.06643887] [1.43394424 0.06542257] [1.32097504 0.06019745]\\n\",\n      \"9.432280196342646 [1.         0.44425828] [1.         0.40673676] [1.         0.34992079]\\n\",\n      \"14.049938924227257 [1.80710662 0.15977388] [1.60599774 0.16410479] [1.6769296  0.04289388]\\n\",\n      \"5.819631747382949 [1.         0.27539881] [1.         0.25182073] [1.         0.22024107]\\n\",\n      \"4.587155863561414 [1.34754112 0.3552102 ] [1.19470642 0.32066332] [1.08583461 0.26067415]\\n\",\n      \"8.347046385947243 [1.         0.46269721] [1.         
0.40744864] [1.         0.33118333]\\n\",\n      \"4.614603587169393 [1.02019149 0.26165914] [1.09545259 0.21167945] [1.28019071 0.10303905]\\n\",\n      \"12.14178602981697 [1.14771746 0.29656745] [1.08535981 0.27125944] [1.00982694 0.22497728]\\n\",\n      \"11.150434599737167 [1.        0.4601252] [1.         0.42149625] [1.         0.35909137]\\n\",\n      \"9.06735902747434 [1.         0.36828413] [1.04873482 0.28796787] [1.        0.2681195]\\n\",\n      \"3.4308747939276114 [1.         0.10330157] [1.         0.08168463] [1.         0.05886519]\\n\",\n      \"7.876642007206043 [1.07691179 0.29187841] [1.06225738 0.25031252] [1.         0.18798912]\\n\",\n      \"8.160522686210383 [1.21685794 0.38932491] [1.07683591 0.37653055] [1.         0.34801735]\\n\",\n      \"6.04527020655973 [1.25256324 0.1766149 ] [1.21054966 0.15639779] [1.16583198 0.14706248]\\n\",\n      \"12.389366504458941 [1.42720437 0.34713776] [1.33544325 0.30839149] [1.70465889 0.15133562]\\n\",\n      \"8.717500487207278 [1.31798132 0.1827578 ] [1.14294955 0.19668818] [1.15604166 0.10476818]\\n\",\n      \"7.832932836765893 [1.63711141 0.02248008] [1.19689107 0.16297951] [1.         0.27430909]\\n\",\n      \"4.731645658724988 [1.26138902 0.32859469] [1.11767305 0.30616558] [1.05693017 0.25837288]\\n\",\n      \"7.232895693505282 [1.         0.37743186] [1.         0.32883281] [1.        0.2575448]\\n\",\n      \"6.646969739152888 [1.08673304 0.19923439] [1.04800513 0.18852304] [1.0031643  0.16843263]\\n\",\n      \"7.003046278484913 [1.17703727 0.22536045] [1.14921878 0.20254484] [1.08429754 0.16369367]\\n\",\n      \"4.04756324291338 [1.16496385 0.05830398] [1.06852961 0.06222365] [1.04884139 0.04064035]\\n\",\n      \"4.45505703871402 [1.22844084 0.45789614] [1.04784352 0.44878295] [1.         0.40978731]\\n\",\n      \"3.0134300763434836 [1.95908206 0.32926769] [1.74670636 0.30533238] [1.39009546 0.2623041 ]\\n\",\n      \"13.348832382173757 [1.         0.39161291] [1.      0.35091] [1. 
        0.28604537]\\n\",\n      \"10.304710295664696 [1.         0.45893432] [1.         0.41864383] [1.         0.35694521]\\n\",\n      \"9.785908436998039 [1.08372382 0.32098789] [1.01837139 0.30181823] [1.         0.23469834]\\n\",\n      \"8.976889411683826 [1.        0.3959776] [1.         0.35170545] [1.02194185 0.27629988]\\n\",\n      \"6.860965238430664 [1.14332745 0.21776023] [1.         0.23488931] [1.         0.17477226]\\n\",\n      \"15.572344511846968 [1.69924381 0.12356604] [2.00048359 0.02007715] [1.6883586  0.05150387]\\n\",\n      \"8.262607543929457 [1.07895506 0.37505895] [1.02074705 0.33990409] [1.0603653 0.25874  ]\\n\",\n      \"4.912597143120675 [1.         0.22822034] [1.         0.19841094] [1.         0.12697328]\\n\",\n      \"6.363851780499002 [1.08254776 0.28384135] [1.02512182 0.25842776] [1.00823914 0.19289162]\\n\",\n      \"12.30228201372977 [1.         0.43639514] [1.         0.40232694] [1.         0.34632636]\\n\",\n      \"11.588900327416342 [1.        0.3693939] [1.        0.3382938] [1.         0.29726383]\\n\",\n      \"7.533132769290531 [1.13759613 0.35266774] [1.18449953 0.29064698] [1.09630072 0.25145613]\\n\",\n      \"11.289702293060968 [1.50749144 0.16226019] [1.45744929 0.13049412] [1.23527793 0.14069548]\\n\",\n      \"4.661812373082756 [1.75910219 0.35953738] [1.53991883 0.33878098] [1.05060127 0.3295266 ]\\n\",\n      \"5.4460323603950425 [1.         0.24312404] [1.         0.21168479] [1.         0.17354187]\\n\",\n      \"8.706177881674797 [1.0799126  0.36515226] [1.01270513 0.32855653] [1.         0.27853921]\\n\",\n      \"8.00077235114351 [1.13150565 0.24845463] [1.13561437 0.21600817] [1.11821358 0.15460734]\\n\",\n      \"8.781319222545337 [1.         0.36041255] [1.         0.32616131] [1.         0.27616279]\\n\",\n      \"5.785696866658025 [1.49238935 0.02230037] [1.16906831 0.13467273] [1.         0.24162001]\\n\",\n      \"8.478908450888166 [1.         0.37765749] [1.         0.32574951] [1.         
0.25176544]\\n\",\n      \"5.8693524852863375 [1.         0.30947209] [1.00026833 0.27751733] [1.04401617 0.21315743]\\n\",\n      \"9.488132419170112 [1.28201812 0.20183249] [1.13807666 0.20555704] [1.15478091 0.11431506]\\n\",\n      \"6.265242354450486 [1.17191284 0.11989237] [1.09503003 0.11605688] [1.1446114  0.02895524]\\n\",\n      \"5.1572482448606625 [1.51466321 0.11179498] [1.47477233 0.08628483] [1.22591373 0.0704102 ]\\n\",\n      \"10.877991402534544 [1.         0.32529559] [1.         0.28170299] [1.         0.22808252]\\n\",\n      \"12.571582373210557 [1.73963119 0.25463019] [1.62983973 0.24014169] [2.00208608 0.06696168]\\n\",\n      \"6.39314313081427 [1.32507248 0.31924145] [1.6659829  0.15794568] [1.29657599 0.13307339]\\n\",\n      \"6.969672823392187 [1.1135408  0.26977364] [1.14255001 0.20186158] [1.12263319 0.12106629]\\n\",\n      \"2.8356203449588677 [1.         0.24473362] [1.         0.23107451] [1.         0.20297414]\\n\",\n      \"10.417356223180978 [1.         0.35744425] [1.         0.28276206] [1.         0.26162718]\\n\",\n      \"13.750913044333586 [1.25206536 0.27687467] [1.16182884 0.25240209] [1.05208969 0.20944306]\\n\",\n      \"12.150945773550827 [1.         0.38371884] [1.         0.34285822] [1.       0.284928]\\n\",\n      \"5.389336374411643 [1.16121849 0.17694264] [1.13249659 0.16548833] [1.07270058 0.14587459]\\n\",\n      \"10.523329017051536 [1.         0.42764285] [1.         0.38407481] [1.         0.31660827]\\n\",\n      \"8.269418511118195 [1.         0.46234438] [1.         0.41054863] [1.         0.33883393]\\n\",\n      \"9.167433277613362 [1.         0.29017861] [1.         0.26614638] [1.         0.23131604]\\n\",\n      \"5.5609186219361595 [1.         0.12462872] [1.         0.10465619] [1.01838654 0.06773848]\\n\",\n      \"9.594943678759213 [1.         0.42423773] [1.         0.38552999] [1.         
0.32468429]\\n\",\n      \"5.3791976022018515 [1.54201572 0.06628701] [1.40167858 0.07165457] [1.2690287  0.06592292]\\n\",\n      \"19.77399570300559 [1.07824599 0.38281923] [1.00398616 0.36607595] [1.1181896  0.23795898]\\n\",\n      \"9.580546761468955 [1.37306584 0.17577431] [1.30179416 0.16933436] [1.21704513 0.14663145]\\n\",\n      \"13.33828004159378 [1.         0.36330007] [1.         0.30110656] [1.07022844 0.1971349 ]\\n\",\n      \"7.26545664572203 [1.00320889 0.27332249] [1.08149899 0.20519183] [1.13360372 0.12610142]\\n\",\n      \"4.020766820685916 [1.84646009 0.34098613] [1.67902187 0.29709569] [1.24429486 0.27404391]\\n\",\n      \"19.273838139796524 [1.23985283 0.15511182] [1.07159674 0.18082891] [1.38035117 0.04312197]\\n\",\n      \"18.747154847936542 [1.2879973  0.17396842] [1.02469479 0.2312224 ] [1.38216191 0.04543679]\\n\",\n      \"7.687501494777646 [1.18911408 0.13068791] [1.21358109 0.10652954] [1.1600769  0.11167813]\\n\",\n      \"17.41626227161925 [1.24286567 0.17648605] [1.         0.23681838] [1.37625362 0.04062536]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"4.608339372745333 [1.74113426 0.34798842] [1.52317166 0.3191022 ] [1.16724291 0.2942582 ]\\n\",\n      \"18.875741071133284 [1.2543283  0.15648117] [1.06798541 0.18717109] [1.34530144 0.04991616]\\n\",\n      \"11.148313150594289 [1.         0.37511565] [1.09509921 0.28287656] [1.14942929 0.17369543]\\n\",\n      \"7.233035561460236 [1.08561326 0.2354706 ] [1.20422952 0.14769201] [1.1853493  0.09605807]\\n\",\n      \"8.863229059881018 [1.41297349 0.19649819] [1.33371269 0.18662474] [1.22792143 0.15976778]\\n\",\n      \"19.45745868435802 [1.49994807 0.26682565] [1.26711735 0.27545718] [1.34188963 0.14555973]\\n\",\n      \"9.468170012731786 [1.         0.42522692] [1.         0.38323518] [1.         
0.32367061]\\n\",\n      \"5.868515168416221 [1.52897243 0.06848402] [1.38364141 0.07468658] [1.25120255 0.06850521]\\n\",\n      \"6.000131214338976 [1.         0.12002132] [1.        0.0980278] [1.0169329  0.05973255]\\n\",\n      \"7.438143572576681 [1.         0.28502836] [1.         0.26303692] [1.         0.22860402]\\n\",\n      \"5.289794627964023 [1.1335567  0.21179734] [1.11170322 0.19507687] [1.07163701 0.16266489]\\n\",\n      \"7.964680770390557 [1.14781153 0.42575943] [1.         0.40190126] [1.         0.33838822]\\n\",\n      \"10.032075253179388 [1.19303526 0.34074064] [1.17790367 0.30367674] [1.13926281 0.24261738]\\n\",\n      \"9.507332791209954 [1.         0.36297141] [1.12549776 0.27003628] [1.00850659 0.26411807]\\n\",\n      \"2.819872748989664 [1.         0.23983558] [1.        0.2266287] [1.        0.1979991]\\n\",\n      \"7.495457263473503 [1.23996667 0.23219605] [1.25828273 0.17540812] [1.2256135  0.09706519]\\n\",\n      \"3.3035183593098116 [1.05401011 0.0786346 ] [1.0064354 0.0834124] [1.00066125 0.06275393]\\n\",\n      \"10.982202783668944 [1.01541626 0.32391523] [1.04335354 0.26363175] [1.02417349 0.21390277]\\n\",\n      \"5.444681012424603 [1.21112693 0.34450231] [1.49266707 0.19883686] [1.21760034 0.16245562]\\n\",\n      \"5.71585520980478 [1.38514994 0.07153152] [1.26599542 0.07115983] [1.28125242 0.00203232]\\n\",\n      \"4.78489739965205 [1.42650945 0.09035256] [1.35666958 0.08265567] [1.25431092 0.05747103]\\n\",\n      \"9.338768434828841 [1.28952457 0.20208728] [1.12484015 0.21011617] [1.14263045 0.1109362 ]\\n\",\n      \"6.073680103146996 [1.00675102 0.26630249] [1.02387754 0.23440817] [1.02017898 0.1963962 ]\\n\",\n      \"8.881113087703643 [1.         0.39147761] [1.13747849 0.30262937] [1.13791602 0.21990965]\\n\",\n      \"6.288096344683792 [1.22002553 0.34162078] [1.05898669 0.32589906] [1.         0.28077121]\\n\",\n      \"6.370692375869698 [1.47119504 0.02273981] [1.14891519 0.14920475] [1.        
0.2303363]\\n\",\n      \"7.1979636175223405 [1.11389871 0.23466379] [1.10305907 0.20912161] [1.06224552 0.16659362]\\n\",\n      \"5.849074255529509 [1.         0.24786564] [1.         0.22290166] [1.         0.18581435]\\n\",\n      \"3.6782213024410426 [1.54906327 0.38499582] [1.38170487 0.35972071] [1.05013607 0.33461307]\\n\",\n      \"4.76170856930791 [1.13572491 0.38745464] [1.03104303 0.37008839] [1.         0.32496164]\\n\",\n      \"5.155595609262272 [1.         0.49398408] [1.         0.46194917] [1.         0.41697888]\\n\",\n      \"6.642882304921422 [1.59828797 0.1202163 ] [1.4459217  0.11160453] [1.19351043 0.12236486]\\n\",\n      \"4.814149145007596 [1.06481287 0.11952267] [1.08056103 0.08142804] [1.04100738 0.05845668]\\n\",\n      \"5.019245705742924 [1.08739817 0.26074022] [1.06015233 0.22453726] [1.08866459 0.14105046]\\n\",\n      \"9.16744872104858 [1.23419148 0.43999962] [1.20295263 0.398843  ] [1.098967   0.34742564]\\n\",\n      \"4.786575682308314 [1.         0.28540823] [1.         0.27403455] [1.         0.18088439]\\n\",\n      \"8.250491528487505 [1.         0.39301486] [1.         0.34821779] [1.         
0.27898305]\\n\",\n      \"8.866099271504389\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import cv2\\n\",\n    \"import os\\n\",\n    \"from os import listdir\\n\",\n    \"from os.path import isfile, join\\n\",\n    \"from PIL import Image as Image\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"from scipy.optimize import curve_fit\\n\",\n    \"\\n\",\n    \"def relit(x, a, b):\\n\",\n    \"    return (a * x.astype(np.float)/255 + b)*255\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"from matplotlib import pyplot as plt\\n\",\n    \"def plshow(im,title='MINE'):\\n\",\n    \"    if len(im.shape)>2:\\n\",\n    \"  #      plt.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))\\n\",\n    \"        plt.imshow(im)\\n\",\n    \"    else:\\n\",\n    \"        plt.imshow(im,cmap='gray')\\n\",\n    \"    plt.title(title)\\n\",\n    \"    plt.rcParams[\\\"figure.figsize\\\"] = (80,12)\\n\",\n    \"    plt.show()\\n\",\n    \"\\n\",\n    \"sd_path = 'dataset/ISTD/train_A'\\n\",\n    \"mask_path = 'dataset/ISTD/train_B'\\n\",\n    \"sdfree_path = 'dataset/ISTD/train_C_fixed_ours'\\n\",\n    \"\\n\",\n    \"out = 'dataset/ISTD/train_params/'\\n\",\n    \"if not os.path.exists(out):\\n\",\n    \"    os.makedirs(out)\\n\",\n    \"\\n\",\n    \"im_list  =  [f for f in listdir(sd_path) if isfile(join(sd_path, f)) and f.endswith('png')]\\n\",\n    \"print(len(im_list),im_list[0])\\n\",\n    \"kernel = np.ones((5,5),np.uint8)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def im_relit(Rpopt,Gpopt,Bpopt,dump):\\n\",\n    \"    #fc this shit\\n\",\n    \"    sdim = dump.copy()\\n\",\n    \"    sdim.setflags(write=1)\\n\",\n    \"    sdim = sdim.astype(np.float)\\n\",\n    \"    sdim[:,:,0] = (sdim[:,:,0]/255) * Rpopt[0] + Rpopt[1]\\n\",\n    \"    sdim[:,:,1] = (sdim[:,:,1]/255) * Gpopt[0] + Gpopt[1]\\n\",\n    \"    sdim[:,:,2] = (sdim[:,:,2]/255) * Bpopt[0] + Bpopt[1]\\n\",\n    \"    sdim = sdim*255\\n\",\n    \"   # print(np.amin(sdim),np.amax(sdim))\\n\",\n    \"    
return sdim\\n\",\n    \"\\n\",\n    \"errors= []\\n\",\n    \"for im in im_list:\\n\",\n    \"    sd = np.asarray(Image.open(join(sd_path,im)))\\n\",\n    \"    mean_sdim = np.mean(sd,axis=2)\\n\",\n    \"    \\n\",\n    \"    mask_ori = np.asarray(Image.open(join(mask_path,im)))\\n\",\n    \"    mask = cv2.erode(mask_ori ,kernel,iterations = 2)\\n\",\n    \"\\n\",\n    \"    \\n\",\n    \"    sdfree = np.asarray(Image.open(join(sdfree_path,im)))\\n\",\n    \"    mean_sdfreeim = np.mean(sdfree,axis=2)\\n\",\n    \"    \\n\",\n    \"    #pixels for regression funtion\\n\",\n    \"    i, j = np.where(np.logical_and(np.logical_and(np.logical_and(mask>=1,mean_sdim>5),mean_sdfreeim<230),np.abs(mean_sdim-mean_sdfreeim)>10))\\n\",\n    \"\\n\",\n    \"    source = sd*0\\n\",\n    \"    source[tuple([i,j])] = sd[tuple([i,j])] \\n\",\n    \"    target = sd*0\\n\",\n    \"    target[tuple([i,j])]= sdfree[tuple([i,j])]\\n\",\n    \"    \\n\",\n    \"    R_s = source[:,:,0][tuple([i,j])]\\n\",\n    \"    G_s = source[:,:,1][tuple([i,j])]\\n\",\n    \"    B_s = source[:,:,2][tuple([i,j])]\\n\",\n    \"    \\n\",\n    \"    R_t = target[:,:,0][tuple([i,j])]\\n\",\n    \"    G_t = target[:,:,1][tuple([i,j])]\\n\",\n    \"    B_t = target[:,:,2][tuple([i,j])]\\n\",\n    \"    \\n\",\n    \"    c_bounds = [[1,-0.1],[10,0.5]]\\n\",\n    \"\\n\",\n    \"    \\n\",\n    \"    Rpopt, pcov = curve_fit(relit, R_s, R_t,bounds=c_bounds)\\n\",\n    \"    Gpopt, pcov = curve_fit(relit, G_s, G_t,bounds=c_bounds)\\n\",\n    \"    Bpopt, pcov = curve_fit(relit, B_s, B_t,bounds=c_bounds)\\n\",\n    \"    \\n\",\n    \"    \\n\",\n    \"    relitim = im_relit(Rpopt,Gpopt,Bpopt,sd)\\n\",\n    \"    \\n\",\n    \"    #final = sd.copy()\\n\",\n    \"    #final[tuple([i,j])] = relitim[tuple([i,j])]\\n\",\n    \"    #final[final>255] =255\\n\",\n    \"    #final[final<0] = 0\\n\",\n    \"\\n\",\n    \"    #plshow(final)\\n\",\n    \"    error = np.mean(np.abs(relitim[tuple([i,j])].astype(np.float) - 
sdfree[tuple([i,j])]).astype(np.float))\\n\",\n    \"    print(error,Rpopt,Gpopt,Bpopt)\\n\",\n    \"    f = open(join(out,im+'.txt'),\\\"a\\\")\\n\",\n    \"    f.write(\\\"%f %f %f %f %f %f\\\"%(Rpopt[1],Rpopt[0],Gpopt[1],Gpopt[0],Bpopt[1],Bpopt[0]))\\n\",\n    \"    f.close()\\n\",\n    \"             \\n\",\n    \"  #  print(error)\\n\",\n    \"    errors.append(error)\\n\",\n    \"                    \\n\",\n    \"    \\n\",\n    \"print(np.mean(errors))\\n\",\n    \"#no bound - 8.55\\n\",\n    \"#### y_bound ###error\\n\",\n    \"#    0.5        8.86\\n\",\n    \"#    0.1        15.692271753155671    \\n\",\n    \"#    0.25       10.830443545867785\\n\",\n    \"#    1          8.86\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.5\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "data_processing/test_mask_generation.py",
    "content": "import os\n\nimport cv2\nimport matplotlib.pyplot as plt\n\nimport PIL.Image as Image\nfrom skimage.filters import threshold_otsu\nfrom skimage import filters\nimport numpy as np\nfrom skimage.morphology import disk\n\n# get mask from shadow image and shadow free image\n\ntrain_A_path = '../data/ISTD+/test_A'\ntrain_C_path = '../data/ISTD+/test_C'\nroot_path = os.listdir(train_A_path)\n\nfor file in root_path:\n    s_name = os.path.join(train_A_path, file)\n    sf_name = os.path.join(train_C_path, file)\n\n    s_img = Image.open(s_name).convert(\"L\")\n    s_img = np.array(s_img).astype(np.float32)\n    sf_img = Image.open(sf_name).convert(\"L\")\n    sf_img = np.array(sf_img).astype(np.float32)\n    diff = (np.asarray(sf_img, dtype='float32') - np.asarray(s_img, dtype='float32'))\n    # diff[diff < 0] = 0\n    L = threshold_otsu(diff)\n    mask = np.float32(diff > L) * 255\n    mask = filters.median(mask, disk(5))\n    mask = Image.fromarray(mask).convert(\"L\")\n    # plt.imshow(mask)\n    # plt.show()\n    mask.save('../data/ISTD+/test_mask/' + file)\n"
  },
  {
    "path": "models/Fusion_model.py",
    "content": "import torch\nfrom collections import OrderedDict\nimport time\nimport numpy as np\nimport os\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\nimport util.util as util\nfrom .distangle_model import DistangleModel\nfrom PIL import ImageOps,Image\nimport cv2\n\n\ndef tensor2im(input_image, imtype=np.uint8):\n    \"\"\"\"Converts a Tensor array into a numpy image array.\n    Parameters:\n        input_image (tensor) --  the input image tensor array\n        imtype (type)        --  the desired type of the converted numpy array\n    \"\"\"\n    if not isinstance(input_image, np.ndarray):\n        if isinstance(input_image, torch.Tensor):  # get the data from a variable\n            image_tensor = input_image.data\n        else:\n            return input_image\n        image_numpy = image_tensor.cpu().float().numpy()  # convert it into a numpy array\n        if image_numpy.shape[0] == 1:  # grayscale to RGB\n            image_numpy = np.tile(image_numpy, (3, 1, 1))\n            # image_numpy = image_numpy.convert('L')\n        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling\n    else:  # if it is a numpy array, do nothing\n        image_numpy = input_image\n    return np.clip(image_numpy, 0, 255).astype(imtype)\n\n\nclass L_TV(nn.Module):\n    def __init__(self):\n        super(L_TV, self).__init__()\n    def forward(self, x):\n        _, _, h, w = x.size()\n        count_h = (h - 1) * w\n        count_w = (w - 1) * h\n\n        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h - 1, :], 2).sum()\n        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w - 1], 2).sum()\n        return (h_tv / count_h + w_tv / count_w) / 2.0\n\n\nclass GradientLoss(nn.Module):\n    def __init__(self, loss_weight=1.0, reduction='mean'):\n        super(GradientLoss, self).__init__()\n        self.loss_weight = 
loss_weight\n        self.reduction = reduction\n        if self.reduction not in ['none', 'mean', 'sum']:\n            raise ValueError(f'Unsupported reduction mode: {self.reduction}. '\n                             f'Supported ones are: {_reduction_modes}')\n\n    def forward(self, pred, target):\n        _, cin, _, _ = pred.shape\n        _, cout, _, _ = target.shape\n        assert cin == 3 and cout == 3\n        kx = torch.Tensor([[1, 0, -1], [2, 0, -2],\n                           [1, 0, -1]]).view(1, 1, 3, 3).to(target)\n        ky = torch.Tensor([[1, 2, 1], [0, 0, 0],\n                           [-1, -2, -1]]).view(1, 1, 3, 3).to(target)\n        kx = kx.repeat((3, 1, 1, 1))\n        ky = ky.repeat((3, 1, 1, 1))\n\n        pred_grad_x = F.conv2d(pred, kx, padding=1, groups=3)\n        pred_grad_y = F.conv2d(pred, ky, padding=1, groups=3)\n        target_grad_x = F.conv2d(target, kx, padding=1, groups=3)\n        target_grad_y = F.conv2d(target, ky, padding=1, groups=3)\n\n        loss = (\n            nn.L1Loss(reduction=self.reduction)(\n                pred_grad_x, target_grad_x) +\n            nn.L1Loss(reduction=self.reduction)(\n                pred_grad_y, target_grad_y))\n        return loss * self.loss_weight\n\n\nclass PoissonGradientLoss(nn.Module):\n    def __init__(self, reduction='mean'):\n        \"\"\"L_{grad} = \\frac{1}{2hw}\\sum_{m=1}^{H}\\sum_{n=1}{W}(\\partial f(I_{Blend}) - \n                       (\\partial f(I_{Source}) + \\partial f(I_{Target})))_{mn}^2\n\n           See **Deep Image Blending** for detail.\n        \"\"\"\n        super(PoissonGradientLoss, self).__init__()\n        self.reduction = reduction\n\n    def forward(self, source, target, blend, mask):\n        f = torch.Tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]).view(1, 1, 3, 3).to(target)\n        f = f.repeat((3, 1, 1, 1))\n        grad_s = F.conv2d(source, f, padding=1, groups=3) * mask\n        grad_t = F.conv2d(target, f, padding=1, groups=3) * (1 - mask)\n       
 grad_b = F.conv2d(blend, f, padding=1, groups=3)\n        return nn.MSELoss(reduction=self.reduction)(grad_b, (grad_t + grad_s))\n\n\nclass FusionModel(DistangleModel):\n    def name(self):\n        return 'fusion net cvpr 21'\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n\n        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')\n        parser.set_defaults(dataset_mode='expo_param')\n        parser.add_argument('--wdataroot',default='None',  help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n        parser.add_argument('--use_our_mask', action='store_true')\n        parser.add_argument('--mask_train',type=str,default=None)\n        parser.add_argument('--mask_test',type=str,default=None)\n        return parser\n\n    def initialize(self, opt):\n        BaseModel.initialize(self, opt)\n        self.isTrain = opt.isTrain\n        self.loss_names = ['G_param', 'alpha', 'rescontruction']\n        self.visual_names = ['input_img', 'litgt', 'alpha_pred', 'out', 'final', 'outgt']\n        self.model_names = ['G', 'M']\n        opt.output_nc = 3 \n\n        self.ks = ks = opt.ks\n        self.n = n = opt.n\n        self.shadow_loss = opt.shadow_loss\n        self.tv_loss = opt.tv_loss\n        self.grad_loss = opt.grad_loss\n        self.pgrad_loss = opt.pgrad_loss\n\n        self.netG = networks.define_G(4, 2 * 3, opt.ngf, 'RESNEXT', opt.norm,\n                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n        self.netM = networks.define_G(1 + 3 + n * 3, ((1 + n) * 3) * 3 * ks * ks, opt.ngf, 'unet_256', opt.norm,\n                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n\n        self.netG.to(self.device)\n        self.netM.to(self.device)\n        print(self.netG)\n        print(self.netM)\n        if self.isTrain:\n            # define loss functions\n            self.MSELoss = torch.nn.MSELoss()\n    
        self.criterionL1 = torch.nn.L1Loss()\n            self.bce = torch.nn.BCEWithLogitsLoss()\n            # initialize optimizers\n            self.optimizers = []\n\n            if opt.optimizer == 'adam':\n                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),\n                                                    lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=1e-5)\n                self.optimizer_M = torch.optim.Adam(self.netM.parameters(),\n                                                    lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=1e-5)\n            elif opt.optimizer == 'sgd':\n                self.optimizer_G = torch.optim.SGD(self.netG.parameters(), momentum=0.9,\n                                                   lr=opt.lr, weight_decay=1e-5)\n                self.optimizer_M = torch.optim.SGD(self.netM.parameters(), momentum=0.9,\n                                                   lr=opt.lr, weight_decay=1e-5) \n            else:\n                assert False\n\n            self.optimizers.append(self.optimizer_G)\n            self.optimizers.append(self.optimizer_M)\n   \n    def set_input(self, input):\n        self.input_img = input['A'].to(self.device)\n        self.shadow_mask = input['B'].to(self.device)\n        self.imname = input['imname']\n        self.shadow_param = input['param'].to(self.device).type(torch.float)\n        self.shadow_mask = (self.shadow_mask > 0.9).type(torch.float) # * 2 - 1\n        self.nim = self.input_img.shape[1]\n        self.shadowfree_img = input['C'].to(self.device)\n        self.shadow_mask_3d = (self.shadow_mask > 0).type(torch.float).expand(self.input_img.shape)\n        self.shadow_mask_dilate = input['B_dilate'].to(self.device)\n        self.shadow_mask_erode = input['B_erode'].to(self.device)\n\n    def forward(self):\n        inputG = torch.cat([self.input_img, self.shadow_mask], 1)\n        shadow_param_pred = self.netG(inputG)\n\n        n = shadow_param_pred.shape[0]\n    
    w = inputG.shape[2]\n        h = inputG.shape[3]\n\n        addgt = self.shadow_param[:, [0, 2, 4]]\n        mulgt = self.shadow_param[:, [1, 3, 5]]\n        addgt = addgt.view(n, 3, 1, 1).expand((n, 3, w, h))\n        mulgt = mulgt.view(n, 3, 1, 1).expand((n, 3, w, h))\n\n        base_shadow_param_pred = shadow_param_pred[:, :2 * 3] # shadow_param_pred.view((n, self.n * 3, 2, 1, 1))\n        self.base_shadow_param_pred = base_shadow_param_pred\n\n        shadow_image = self.input_img.clone() / 2 + 0.5\n        base_shadow_output = shadow_image * base_shadow_param_pred[:, :3].view((n, 3, 1, 1)) + \\\n                             base_shadow_param_pred[:, 3:].view((n, 3, 1, 1))\n        shadow_output_list = []\n        for i in range(0, self.n - 1):\n            if i % 2 == 0:\n                scale = 1 + i * 0.01\n            else:\n                scale = 1 - i * 0.01\n            shadow_output_list.append(base_shadow_output * scale)\n        shadow_output = torch.cat([base_shadow_output] + shadow_output_list, dim=1)\n        self.lit = torch.cat([base_shadow_output] + shadow_output_list, dim=-1) * 2 - 1\n\n        shadow_output = shadow_output * 2 - 1\n        self.shadow_output = shadow_output\n\n        self.litgt = self.input_img.clone() / 2 + 0.5\n        self.litgt = (self.litgt * mulgt + addgt) * 2 - 1 # [-1, 1]\n\n        inputM = torch.cat([self.input_img, shadow_output, self.shadow_mask], 1)\n        out = torch.cat([self.input_img, shadow_output], 1)\n        out = out / 2 + 0.5\n        out_matrix = F.unfold(out, stride=1, padding=self.ks // 2, kernel_size=self.ks) # N, C x \\mul_(kernel_size), L\n\n        kernel = self.netM(inputM) # b, (3+1)*n * 3 * ks * ks, Tanh\n\n        b, c, h, w = self.input_img.shape\n        output = []\n        for i in range(b):\n            # feature = out[i, ...]\n            feature = out_matrix[i, ...] # ((1 + n) * 3) * ks * ks, L\n            weight = kernel[i, ...] 
# ((1 + n) * 3) * 3 * ks * ks, H, W\n            feature = feature.unsqueeze(0) # 1, C, L\n            weight = weight.view((3, (self.n + 1) * 3 * self.ks * self.ks, h * w))\n            weight = F.softmax(weight, dim=1)\n            iout = feature * weight # (3, C, L)\n            iout = torch.sum(iout, dim=1, keepdim=False)\n            iout = iout.view((1, 3, h, w))\n\n            output.append(iout)\n        self.final = torch.cat(output, dim=0) * 2 -1\n\n\n    def backward(self):\n        criterion = self.criterionL1\n        lambda_ = self.opt.lambda_L1\n\n        addgt = self.shadow_param[:, [0, 2, 4]] # [b, 3]\n        mulgt = self.shadow_param[:, [1, 3, 5]] # [b, 3]\n\n        loss_G_param_mul = self.MSELoss(self.base_shadow_param_pred[:, :3], mulgt) * lambda_\n        loss_G_param_add = self.MSELoss(self.base_shadow_param_pred[:, 3:], addgt) * lambda_\n        self.loss_G_param = (loss_G_param_add + loss_G_param_mul) / 2.0 * self.shadow_loss\n\n        if self.tv_loss > 0:\n            tv_loss = L_TV()(self.final - self.shadowfree_img) * lambda_ * self.tv_loss\n        else:\n            tv_loss = 0.0\n        \n        if self.grad_loss > 0:\n            grad_loss = GradientLoss()(self.final, self.shadowfree_img) * lambda_ * self.grad_loss\n        else:\n            grad_loss = 0.0\n\n        if self.pgrad_loss > 0:\n            pgrad_loss = PoissonGradientLoss()(target=self.input_img, blend=self.final,\n                                               source=self.shadowfree_img, mask=self.shadow_mask_dilate) \\\n                                               * lambda_ * self.pgrad_loss\n        else:\n            pgrad_loss = 0.0\n\n        self.loss_rescontruction = criterion(self.final, self.shadowfree_img) * lambda_\n        self.loss = self.loss_rescontruction + self.loss_G_param + tv_loss + grad_loss + pgrad_loss\n        self.loss.backward()\n\n    def optimize_parameters(self):\n        self.netM.zero_grad()\n        self.netG.zero_grad()\n        
self.forward()\n        self.optimizer_G.zero_grad()\n        self.optimizer_M.zero_grad()\n        self.backward()\n        self.optimizer_G.step()\n        self.optimizer_M.step()\n    \n    def zero_grad(self):\n        self.netM.zero_grad()\n        self.netG.zero_grad()\n        self.optimizer_G.zero_grad()\n        self.optimizer_M.zero_grad()\n\n    def vis(self, e, s, path='', eval=False):\n        if len(path) > 0:\n            save_dir = os.path.join(self.save_dir, path)\n        else:\n            save_dir = self.save_dir\n        if not os.path.isdir(save_dir):\n            os.mkdir(save_dir)\n        shadow = self.input_img\n        output = self.final\n        gt = self.shadowfree_img\n        if eval:\n            img = self.final[0, ...]\n            filename = os.path.join(save_dir, self.imname[0])\n        else:\n            img = torch.cat([shadow, output, gt, self.litgt, self.lit], axis=-1)[0, ...]\n            filename = os.path.join(save_dir, \"epoch_%d_step_%d.png\" % (e, s))\n        img = tensor2im(img)\n\n        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n        cv2.imwrite(filename, img)\n\n"
  },
  {
    "path": "models/Refine_model.py",
    "content": "import torch\nfrom collections import OrderedDict\nimport time\nimport numpy as np\nimport os\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\nimport util.util as util\nfrom .distangle_model import DistangleModel\nfrom PIL import ImageOps, Image\nimport cv2\n\n\ndef tensor2im(input_image, imtype=np.uint8):\n    if not isinstance(input_image, np.ndarray):\n        if isinstance(input_image, torch.Tensor):  # get the data from a variable\n            image_tensor = input_image.data\n        else:\n            return input_image\n        image_numpy = image_tensor.cpu().float().numpy()  # convert it into a numpy array\n        if image_numpy.shape[0] == 1:  # grayscale to RGB\n            image_numpy = np.tile(image_numpy, (3, 1, 1))\n            # image_numpy = image_numpy.convert('L')\n        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling\n    else:  # if it is a numpy array, do nothing\n        image_numpy = input_image\n    return np.clip(image_numpy, 0, 255).astype(imtype)\n\n\nclass L_TV(nn.Module):\n    def __init__(self):\n        super(L_TV, self).__init__()\n\n    def forward(self, x):\n        _, _, h, w = x.size()\n        count_h = (h - 1) * w\n        count_w = (w - 1) * h\n\n        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h - 1, :], 2).sum()\n        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w - 1], 2).sum()\n        return (h_tv / count_h + w_tv / count_w) / 2.0\n\n\nclass GradientLoss(nn.Module):\n    def __init__(self, loss_weight=1.0, reduction='mean'):\n        super(GradientLoss, self).__init__()\n        self.loss_weight = loss_weight\n        self.reduction = reduction\n        if self.reduction not in ['none', 'mean', 'sum']:\n            raise ValueError(f'Unsupported reduction mode: {self.reduction}. 
'\n                             f'Supported ones are: {_reduction_modes}')\n\n    def forward(self, pred, target):\n        _, cin, _, _ = pred.shape\n        _, cout, _, _ = target.shape\n        assert cin == 3 and cout == 3\n        kx = torch.Tensor([[1, 0, -1], [2, 0, -2],\n                           [1, 0, -1]]).view(1, 1, 3, 3).to(target)\n        ky = torch.Tensor([[1, 2, 1], [0, 0, 0],\n                           [-1, -2, -1]]).view(1, 1, 3, 3).to(target)\n        kx = kx.repeat((cout, 1, 1, 1))\n        ky = ky.repeat((cout, 1, 1, 1))\n\n        pred_grad_x = F.conv2d(pred, kx, padding=1, groups=3)\n        pred_grad_y = F.conv2d(pred, ky, padding=1, groups=3)\n        target_grad_x = F.conv2d(target, kx, padding=1, groups=3)\n        target_grad_y = F.conv2d(target, ky, padding=1, groups=3)\n\n        loss = (\n            nn.L1Loss(reduction=self.reduction)\n            (pred_grad_x, target_grad_x) +\n            nn.L1Loss(reduction=self.reduction)\n            (pred_grad_y, target_grad_y))\n        return loss * self.loss_weight\n\n\nclass PoissonGradientLoss(nn.Module):\n    def __init__(self, reduction='mean'):\n        \"\"\"L_{grad} = \\frac{1}{2hw}\\sum_{m=1}^{H}\\sum_{n=1}{W}(\\partial f(I_{Blend}) - \n                       (\\partial f(I_{Source}) + \\partial f(I_{Target})))_{mn}^2\n           See **Deep Image Blending** for detail.\n        \"\"\"\n        super(PoissonGradientLoss, self).__init__()\n        self.reduction = reduction\n\n    def forward(self, source, target, blend, mask):\n        f = torch.Tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]).view(1, 1, 3, 3).to(target)\n        f = f.repeat((3, 1, 1, 1))\n        grad_s = F.conv2d(source, f, padding=1, groups=3) * mask\n        grad_t = F.conv2d(target, f, padding=1, groups=3) * (1 - mask)\n        grad_b = F.conv2d(blend, f, padding=1, groups=3)\n        return nn.MSELoss(reduction=self.reduction)(grad_b, (grad_t + grad_s))\n\n\nclass RefineModel(DistangleModel):\n    def 
name(self):\n        return 'auto exposure cvpr21'\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n\n        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')\n        parser.set_defaults(dataset_mode='expo_param')\n        parser.add_argument('--wdataroot', default='None',\n                            help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n        parser.add_argument('--use_our_mask', action='store_true')\n        parser.add_argument('--mask_train', type=str, default=None)\n        parser.add_argument('--mask_test', type=str, default=None)\n        return parser\n\n    def initialize(self, opt):\n        BaseModel.initialize(self, opt)\n        self.isTrain = opt.isTrain\n        self.loss_names = ['G_param', 'alpha', 'rescontruction']\n        self.visual_names = ['input_img', 'litgt', 'alpha_pred', 'out', 'final', 'outgt']\n        self.model_names = ['G', 'M', 'R']\n        # load/define networks\n        opt.output_nc = 3\n\n        self.ks = ks = opt.ks\n        self.rks = opt.rks\n        self.n = n = opt.n\n        self.shadow_loss = opt.shadow_loss\n        self.tv_loss = opt.tv_loss\n        self.grad_loss = opt.grad_loss\n        self.pgrad_loss = opt.pgrad_loss\n\n        self.netG = networks.define_G(4, 2 * 3, opt.ngf, 'RESNEXT', opt.norm,\n                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n        self.netM = networks.define_G(1 + 3 + n * 3, ((1 + n) * 3) * 3 * ks * ks, opt.ngf, 'unet_256', opt.norm,\n                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n        self.netR = networks.define_G(3 + 3 + 2, 3 * 3 * self.rks * self.rks, opt.ngf, 'unet_256', opt.norm,\n                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n\n        self.netG.to(self.device)\n        self.netM.to(self.device)\n        
self.netR.to(self.device)\n        self.netG.eval()\n        self.netM.eval()\n        self.netG.requires_grad = False\n        self.netM.requires_grad = False\n        if self.isTrain:\n            # define loss functions\n            self.MSELoss = torch.nn.MSELoss()\n            self.criterionL1 = torch.nn.L1Loss()\n            self.bce = torch.nn.BCEWithLogitsLoss()\n            # initialize optimizers\n            self.optimizers = []\n\n            if opt.optimizer == 'adam':\n                self.optimizer_R = torch.optim.Adam(self.netR.parameters(),\n                                                    lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=1e-5)\n            elif opt.optimizer == 'sgd':\n                self.optimizer_R = torch.optim.SGD(self.netR.parameters(), momentum=0.9,\n                                                   lr=opt.lr, weight_decay=1e-5)\n            else:\n                assert False\n\n            self.optimizers.append(self.optimizer_R)\n\n    def set_input(self, input):\n        self.input_img = input['A'].to(self.device)\n        self.shadow_mask = input['B'].to(self.device)\n        self.imname = input['imname']\n        if self.isTrain:\n            self.shadow_param = input['param'].to(self.device).type(torch.float)\n        self.shadow_mask = (self.shadow_mask > 0.9).type(torch.float)  # * 2 - 1\n        self.nim = self.input_img.shape[1]\n        self.shadowfree_img = input['C'].to(self.device)\n        self.shadow_mask_3d = (self.shadow_mask > 0).type(torch.float).expand(self.input_img.shape)\n        self.shadow_mask_dilate = input['B_dilate'].to(self.device)\n        self.shadow_mask_erode = input['B_erode'].to(self.device)\n\n    def forward(self):\n        inputG = torch.cat([self.input_img, self.shadow_mask], 1)\n        shadow_param_pred = self.netG(inputG)\n\n        n = shadow_param_pred.shape[0]\n        w = inputG.shape[2]\n        h = inputG.shape[3]\n\n        base_shadow_param_pred = shadow_param_pred[:, 
:2 * 3]  # shadow_param_pred.view((n, self.n * 3, 2, 1, 1))\n        self.base_shadow_param_pred = base_shadow_param_pred\n\n        shadow_image = self.input_img.clone() / 2 + 0.5\n        base_shadow_output = shadow_image * base_shadow_param_pred[:, :3].view((n, 3, 1, 1)) + \\\n                             base_shadow_param_pred[:, 3:].view((n, 3, 1, 1))\n        shadow_output_list = []\n        for i in range(0, self.n - 1):\n            if i % 2 == 0:\n                scale = 1 + i * 0.01\n            else:\n                scale = 1 - i * 0.01\n            shadow_output_list.append(base_shadow_output * scale)\n        shadow_output = torch.cat([base_shadow_output] + shadow_output_list, dim=1)\n        self.lit = torch.cat([base_shadow_output] + shadow_output_list, dim=-1) * 2 - 1\n\n        shadow_output = shadow_output * 2 - 1\n        self.shadow_output = shadow_output\n\n        inputM = torch.cat([self.input_img, shadow_output, self.shadow_mask], 1)\n        out = torch.cat([self.input_img, shadow_output], 1)\n        out = out / 2 + 0.5\n        out_matrix = F.unfold(out, stride=1, padding=self.ks // 2, kernel_size=self.ks)  # N, C x \\mul_(kernel_size), L\n\n        kernel = self.netM(inputM)  # b, (3+1)*n * 3 * ks * ks, Tanh\n\n        b, c, h, w = self.input_img.shape\n        output = []\n        for i in range(b):\n            # feature = out[i, ...]\n            feature = out_matrix[i, ...]  # ((1 + n) * 3) * ks * ks, L\n            weight = kernel[i, ...]  
# ((1 + n) * 3) * 3 * ks * ks, H, W\n            feature = feature.unsqueeze(0)  # 1, C, L\n            weight = weight.view((3, (self.n + 1) * 3 * self.ks * self.ks, h * w))\n            weight = F.softmax(weight, dim=1)\n            iout = feature * weight  # (3, C, L)\n            iout = torch.sum(iout, dim=1, keepdim=False)\n            iout = iout.view((1, 3, h, w))\n            output.append(iout)\n        self.final = torch.cat(output, dim=0) * 2 - 1\n\n        final_output = self.final.detach()\n        final_output.requires_grad = False\n        rkernel = self.netR(torch.cat([self.input_img, final_output, self.shadow_mask,\n                                       self.shadow_mask_dilate - self.shadow_mask_erode], 1))\n\n        final_output = final_output / 2 + 0.5\n        rout_matrix = F.unfold(final_output, stride=1, padding=self.rks // 2, kernel_size=self.rks)\n        output = []\n        for i in range(b):\n            feature = rout_matrix[i, ...]  # ((1 + n) * 3) * ks * ks, L\n            weight = rkernel[i, ...]  
# ((1 + n) * 3) * 3 * ks * ks, H, W\n            feature = feature.unsqueeze(0)  # 1, C, L\n            weight = weight.view((3, 3 * self.rks * self.rks, h * w))\n            iout = feature * weight  # (3, C, L)\n            iout = torch.sum(iout, dim=1, keepdim=False)\n            iout = iout.view((1, 3, h, w))\n\n            output.append(iout)\n        self.rfinal = torch.cat(output, dim=0) * 2 - 1\n\n    def backward(self):\n        # criterion = self.criterionL1\n        lambda_ = self.opt.lambda_L1\n\n        if self.tv_loss > 0:\n            tv_loss = L_TV()(self.rfinal - self.shadowfree_img) * lambda_ * self.tv_loss\n        else:\n            tv_loss = 0.0\n\n        if self.grad_loss > 0:\n            grad_loss = GradientLoss()(self.rfinal, self.shadowfree_img) * lambda_ * self.grad_loss\n        else:\n            grad_loss = 0.0\n\n        if self.pgrad_loss > 0:\n            pgrad_loss = PoissonGradientLoss()(target=self.input_img, blend=self.rfinal,\n                                               source=self.shadowfree_img, mask=self.shadow_mask_dilate) \\\n                         * lambda_ * self.pgrad_loss\n        else:\n            pgrad_loss = 0.0\n\n        self.loss_rescontruction = nn.L1Loss()(self.rfinal, self.shadowfree_img) * lambda_\n        self.loss = self.loss_rescontruction + tv_loss + grad_loss + pgrad_loss\n        self.loss.backward()\n\n    def optimize_parameters(self):\n        self.netM.zero_grad()\n        self.netG.zero_grad()\n        self.zero_grad()\n        self.forward()\n        self.optimizer_R.zero_grad()\n        self.backward()\n        self.optimizer_R.step()\n\n    def zero_grad(self):\n        self.netR.zero_grad()\n        self.optimizer_R.zero_grad()\n\n    def vis(self, e, s, path='', eval=False):\n        if len(path) > 0:\n            save_dir = os.path.join(self.save_dir, path)\n        else:\n            save_dir = self.save_dir\n        if not os.path.isdir(save_dir):\n            os.mkdir(save_dir)\n     
   shadow = self.input_img\n        ooutput = self.final\n        output = self.rfinal\n        gt = self.shadowfree_img\n\n        if not eval:\n            img = torch.cat([shadow, ooutput, output, gt], dim=-1)[0, ...]\n            filename = os.path.join(save_dir, self.imname[0])\n\n            img = tensor2im(img)\n            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n            cv2.imwrite(filename, img)\n\n        else:\n            filename = os.path.join(save_dir, self.imname[0])\n            img = output[0, ...]\n            img = tensor2im(img)\n            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n            cv2.imwrite(filename, img)\n\n            filename = os.path.join(save_dir, self.imname[0].replace('.png', '-o.png'))\n            img = ooutput[0, ...]\n            img = tensor2im(img)\n            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n            cv2.imwrite(filename, img)\n"
  },
  {
    "path": "models/__init__.py",
    "content": "import importlib\nfrom models.base_model import BaseModel\n\n\ndef find_model_using_name(model_name):\n    # Given the option --model [modelname],\n    # the file \"models/modelname_model.py\"\n    # will be imported.\n    model_filename = \"models.\" + model_name + \"_model\"\n    modellib = importlib.import_module(model_filename)\n\n    # In the file, the class called ModelNameModel() will\n    # be instantiated. It has to be a subclass of BaseModel,\n    # and it is case-insensitive.\n    model = None\n    target_model_name = model_name.replace('_', '') + 'model'\n    for name, cls in modellib.__dict__.items():\n        if name.lower() == target_model_name.lower() \\\n           and issubclass(cls, BaseModel):\n            model = cls\n\n    if model is None:\n        print(\"In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase.\" % (model_filename, target_model_name))\n        exit(0)\n\n    return model\n\n\ndef get_option_setter(model_name):\n    model_class = find_model_using_name(model_name)\n    return model_class.modify_commandline_options\n\n\ndef create_model(opt):\n    model = find_model_using_name(opt.model)\n    instance = model()\n    instance.initialize(opt)\n    print(\"model [%s] was created\" % (instance.name()))\n    return instance\n"
  },
  {
    "path": "models/base_model.py",
    "content": "import os\nimport time\nimport torch\nfrom collections import OrderedDict\nfrom . import networks\nimport util.util as util\nimport numpy as np\nclass BaseModel():\n\n    # modify parser to add command line options,\n    # and also change the default values if needed\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        return parser\n    \n\n    def train(self):\n        print('switching to training mode')\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, 'net' + name)\n                net.train()\n    # make models eval mode during test time\n    def eval(self):\n        print('switching to testing mode')\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, 'net' + name)\n                net.eval()\n\n    def name(self):\n        return 'BaseModel'\n\n    def initialize(self, opt):\n        self.epoch=0\n        self.opt = opt\n        self.gpu_ids = opt.gpu_ids\n        self.isTrain = opt.isTrain\n        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')\n        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)\n        if not os.path.isdir(self.save_dir):\n            os.mkdir(self.save_dir)\n        # if opt.resize_or_crop != 'scale_width':\n        #     torch.backends.cudnn.benchmark = True\n        self.loss_names = []\n        self.model_names = []\n        self.visual_names = []\n        self.image_paths = []\n\n    #def set_input(self, input):\n    #    self.input = input\n\n\n    def set_input(self, input):\n        self.input_img = input['A'].to(self.device)\n        self.shadow_mask = input['B'].to(self.device)\n        self.shadow_mask = (self.shadow_mask>0.9).type(torch.float)*2-1\n        #self.shadow_mask = (self.shadow_mask==1).type(torch.float)*2-1\n        self.nim = self.input_img.shape[1]\n        
self.shadowfree_img = input['C'].to(self.device)\n        self.shadow_mask_3d= (self.shadow_mask>0).type(torch.float).expand(self.input_img.shape)   \n        #self.shadow_mask_3d_over = (self.shadow_mask_over>0).type(torch.float).expand(self.input_img.shape)\n\n    def get_prediction(self,input):\n        self.input_img = input['A'].to(self.device)\n        self.shadow_mask = input['B'].to(self.device)\n        self.shadow_mask = (self.shadow_mask>0.9).type(torch.float)*2-1\n        self.shadow_mask_3d= (self.shadow_mask>0).type(torch.float).expand(self.input_img.shape)   \n        inputG = torch.cat([self.input_img,self.shadow_mask],1)\n        out = self.netG(inputG)\n        return util.tensor2im(out)\n\n    def forward(self):\n        pass\n\n    # load and print networks; create schedulers\n    def setup(self, opt, parser=None):\n        print(self.name)\n        if self.isTrain:\n            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]\n\n        if not self.isTrain or opt.continue_train or opt.finetuning:\n            print(\"LOADING %s\"%(self.name))\n            self.load_networks(opt.epoch)\n        self.print_networks(opt.verbose)\n\n\n    # used in test time, wrapping `forward` in no_grad() so we don't save\n    # intermediate steps for backprop\n    def test(self):\n        with torch.no_grad():\n            self.forward()\n\n    # get image paths\n    def get_image_paths(self):\n        return self.image_paths\n\n    def optimize_parameters(self):\n        pass\n\n    # update learning rate (called once every epoch)\n    def update_learning_rate(self,loss=None):\n        for scheduler in self.schedulers:\n            if not loss:\n                scheduler.step()\n            else:\n                scheduler.step(loss)\n\n        lr = self.optimizers[0].param_groups[0]['lr']\n        print('learning rate = %.7f' % lr)\n\n    # return visualization images. 
train.py will display these images, and save the images to a html\n    def get_current_visuals(self):\n        t= time.time()\n        nim = self.shadow.shape[0]\n        visual_ret = OrderedDict()\n        all =[]\n        for i in range(0,min(nim-1,5)):\n            row=[]\n            for name in self.visual_names:\n                if isinstance(name, str):\n                    if hasattr(self,name):\n                        im = util.tensor2im(getattr(self, name).data[i:i+1,:,:,:])\n                        row.append(im)\n            row=tuple(row)\n            all.append(np.hstack(row))\n        all = tuple(all)\n        \n        allim = np.vstack(all)\n        return OrderedDict([(self.opt.name,allim)])\n    \n    # return traning losses/errors. train.py will print out these errors as debugging information\n    def get_current_losses(self):\n        errors_ret = OrderedDict()\n        for name in self.loss_names:\n            if isinstance(name, str):\n                # float(...) 
works for both scalar tensor and float number\n                if hasattr(self, 'loss_' + name):\n                    errors_ret[name] = float(getattr(self, 'loss_' + name))\n        return errors_ret\n\n    # save models to the disk\n    def save_networks(self, epoch):\n        for name in self.model_names:\n            if isinstance(name, str):\n                save_filename = '%s_net_%s.pth' % (epoch, name)\n                save_path = os.path.join(self.save_dir, save_filename)\n                net = getattr(self, 'net' + name)\n\n                if len(self.gpu_ids) > 0 and torch.cuda.is_available():\n                    torch.save(net.module.cpu().state_dict(), save_path)\n                    net.cuda(self.gpu_ids[0])\n                else:\n                    torch.save(net.cpu().state_dict(), save_path)\n\n    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):\n        key = keys[i]\n        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer\n            if module.__class__.__name__.startswith('InstanceNorm') and \\\n                    (key == 'running_mean' or key == 'running_var'):\n                if getattr(module, key) is None:\n                    state_dict.pop('.'.join(keys))\n            if module.__class__.__name__.startswith('InstanceNorm') and \\\n               (key == 'num_batches_tracked'):\n                state_dict.pop('.'.join(keys))\n        else:\n            self.__patch_instance_norm_state_dict(state_dict, getattr(module,key), keys, i + 1)\n\n    # load models from the disk\n    def load_networks(self, epoch, save_dir=None):\n        print(epoch)\n        if save_dir is None:\n            save_dir = self.save_dir\n        \n        for name in self.model_names:\n            if isinstance(name, str):\n                load_filename = '%s_net_%s.pth' % (epoch, name)\n                load_path = os.path.join(save_dir, load_filename)\n                if self.opt.finetuning:\n\n                    
load_filename = '%s_net_%s.pth' % (self.opt.finetuning_epoch, name)\n                    load_path = os.path.join(self.opt.finetuning_dir, load_filename)\n\n                net = getattr(self, 'net' + name)\n                if isinstance(net, torch.nn.DataParallel):\n                    net = net.module\n                print('loading the model from %s' % load_path)\n                # if you are using PyTorch newer than 0.4 (e.g., built from\n                # GitHub source), you can remove str() on self.device\n                state_dict = torch.load(load_path, map_location=str(self.device))\n                #\n                if hasattr(state_dict, '_metadata'):\n                    del state_dict._metadata\n\n                # patch InstanceNorm checkpoints prior to 0.4\n                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop\n                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))\n                net.load_state_dict(state_dict)\n\n    # print network information\n    def print_networks(self, verbose):\n        print('---------- Networks initialized -------------')\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, 'net' + name)\n                num_params = 0\n                for param in net.parameters():\n                    num_params += param.numel()\n                print(net)\n                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))\n        print('-----------------------------------------------')\n\n    # set requies_grad=Fasle to avoid computation\n    def set_requires_grad(self, nets, requires_grad=False):\n        if not isinstance(nets, list):\n            nets = [nets]\n        for net in nets:\n            if net is not None:\n                for param in net.parameters():\n                    param.requires_grad = requires_grad\n"
  },
  {
    "path": "models/distangle_model.py",
    "content": "import torch\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\nimport util.util as util\n\n\n\nclass DistangleModel(BaseModel):\n    def name(self):\n        return 'DistangleModel'\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n\n        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')\n        parser.set_defaults(dataset_mode='expo_param')\n        parser.set_defaults(netG='RESNEXT')\n        if is_train:\n            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')\n\n        return parser\n\n    def initialize(self, opt):\n        BaseModel.initialize(self, opt)\n        self.isTrain = opt.isTrain\n        self.loss_names = ['G']\n        # specify the images you want to save/display. The program will call base_model.get_current_visuals\n        self.visual_names = ['input_img', 'shadow_mask','out','outgt']\n        # specify the models you want to save to the disk. 
The program will call base_model.save_networks and base_model.load_networks\n        if self.isTrain:\n            self.model_names = ['G']\n        else:  # during test time, only load Gs\n            self.model_names = ['G']\n        # load/define networks\n        opt.output_nc= 3 if opt.task=='sr' else 1 #3 for shadow removal, 1 for detection\n        self.netG = networks.define_G(4, opt.output_nc, opt.ngf, 'RESNEXT', opt.norm,\n                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n        self.netG.to(self.device)\n        print(self.netG)\n        if self.isTrain:\n            # define loss functions\n            self.criterionL1 = torch.nn.L1Loss()\n            self.bce = torch.nn.BCEWithLogitsLoss()\n            # initialize optimizers\n            self.optimizers = []\n            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),\n                                                lr=opt.lr, betas=(opt.beta1, 0.999),weight_decay=1e-5)\n            self.optimizers.append(self.optimizer_G)\n\n    def set_input(self, input):\n        self.input_img = input['A'].to(self.device)\n        self.shadow_mask = input['B'].to(self.device)\n        self.shadow_param = input['param'].to(self.device).type(torch.float)\n        self.shadow_mask = (self.shadow_mask>0.9).type(torch.float)*2-1\n        self.nim = self.input_img.shape[1]\n        self.shadowfree_img = input['C'].to(self.device)\n        self.shadow_mask_3d= (self.shadow_mask>0).type(torch.float).expand(self.input_img.shape)   \n\n    def forward(self):\n        inputG = torch.cat([self.input_img, self.shadow_mask], 1)\n        self.Gout = self.netG(inputG)\n        self.lit = self.input_img.clone() / 2 + 0.5\n        n = self.Gout.shape[0]\n\n        add = add.view(n, 3, 1, 1).expand((n, 3, 256, 256))\n        mul = mul.view(n, 3, 1, 1).expand((n, 3, 256, 256))\n\n        self.litgt = (self.input_img.clone() + 1) / 2\n        self.lit = self.lit * mul + 
add\n        self.litgt = self.litgt * mulgt + addgt\n        self.out = (self.input_img / 2 + 0.5) * (1 - self.shadow_mask_3d) + self.lit * self.shadow_mask_3d\n        self.out = self.out * 2 - 1\n        self.outgt = (self.input_img / 2 + 0.5) * (1 - self.shadow_mask_3d) + self.litgt * self.shadow_mask_3d\n        self.outgt = self.outgt * 2 - 1\n        self.alpha = torch.mean(self.shadowfree_img / self.lit,dim=1,keepdim=True)\n\n\n    def get_prediction(self,input):\n        self.input_img = input['A'].to(self.device)\n        self.shadow_mask = input['B'].to(self.device)\n        inputG = torch.cat([self.input_img,self.shadow_mask],1)\n        self.shadow_mask = (self.shadow_mask>0.9).type(torch.float)*2-1\n        self.shadow_mask_3d= (self.shadow_mask>0).type(torch.float).expand(self.input_img.shape)   \n        self.Gout = self.netG(inputG)\n        self.lit = self.input_img.clone()/2+0.5\n        add = self.Gout[:,[0,2,4]]\n        mul = self.Gout[:,[1,3,5]]\n        n = self.Gout.shape[0]\n        add = add.view(n,3,1,1).expand((n,3,256,256))\n        mul = mul.view(n,3,1,1).expand((n,3,256,256))\n        self.lit = self.lit*mul + add\n        self.out = (self.input_img/2+0.5)*(1-self.shadow_mask_3d) + self.lit*self.shadow_mask_3d\n        self.out = self.out*2-1\n        return util.tensor2im(self.out,scale =0) \n\n    def backward_G(self):\n        criterion = self.criterionL1 if self.opt.task =='sr' else self.bce\n        lambda_ = self.opt.lambda_L1 if self.opt.task =='sr' else 1\n        self.loss_G = criterion(self.Gout, self.shadow_param) * lambda_\n        self.loss_G.backward()\n\n    def optimize_parameters(self):\n        self.forward()\n        self.optimizer_G.zero_grad()\n        self.backward_G()\n        self.optimizer_G.step()\n\n\nif __name__=='__main__':\n    parser = argparse.ArgumentParser()\n"
  },
  {
    "path": "models/loss_function.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\ndef smooth_loss(pred_map):\n    def gradient(pred):\n        D_dy = pred[:, :, 1:] - pred[:, :, :-1]\n        D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1]\n        return D_dx, D_dy\n\n    loss = 0\n    weight = 1.\n\n    dx, dy = gradient(pred_map)\n    dx2, dxdy = gradient(dx)\n    dydx, dy2 = gradient(dy)\n    loss += (dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean())*weight\n    return loss\n\n"
  },
  {
    "path": "models/networks.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.optim import lr_scheduler\n###############################################################################\n# Helper Functions\n###############################################################################\n\n\ndef get_norm_layer(norm_type='instance'):\n    if norm_type == 'batch':\n        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)\n    elif norm_type == 'instance':\n        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)\n    elif norm_type == 'none':\n        norm_layer = None\n    else:\n        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)\n    return norm_layer\n\n\ndef get_scheduler(optimizer, opt):\n    if opt.lr_policy == 'lambda':\n        def lambda_rule(epoch):\n            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)\n            return lr_l\n        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\n    elif opt.lr_policy == 'step':\n        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)\n    elif opt.lr_policy == 'shadow_step':\n        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[70000,90000,13200], gamma=0.3)\n    elif opt.lr_policy == 'plateau':\n        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)\n    elif opt.lr_policy == 'cosine':\n        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)\n    else:\n        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)\n    return scheduler\n\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n  
          if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\ndef init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):\n    if len(gpu_ids) > 0:\n        assert(torch.cuda.is_available())\n        net.to(gpu_ids[0])\n        net = torch.nn.DataParallel(net, gpu_ids)\n    init_weights(net, init_type, gain=init_gain)\n    return net\n\ndef define_vgg(num_input,num_classes,init_type='normal', init_gain=0.02, gpu_ids=[]):\n    print(gpu_ids)\n    from .vgg import create_vgg\n    net = create_vgg(num_input,num_classes)\n    net.to(gpu_ids[0])\n    net = torch.nn.DataParallel(net,gpu_ids)\n    init_weights(net,init_type,gain=init_gain)\n    return net\n\n\n\ndef define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):\n    net = None\n    norm_layer = get_norm_layer(norm_type=norm)\n\n    if netG == 'resnet_9blocks':\n        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)\n    elif netG == 'resnet_6blocks':\n        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)\n    elif 
netG == 'unet_128':\n        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n    elif netG == 'unet_256':\n        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n    elif netG == 'unet_32':\n        net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n    elif netG == 'RESNEXT':\n        from .resnet import resnext101_32x8d\n        net = resnext101_32x8d(pretrained=False, num_classes=output_nc, num_inputchannels=input_nc)\n        if len(gpu_ids)>0:\n            assert(torch.cuda.is_available())\n            net.to(gpu_ids[0])\n            net = torch.nn.DataParallel(net,gpu_ids)\n        return net\n    elif netG == 'resnet50':\n        from .resnet import resnet50\n        net = resnet50(pretrained=False, num_classes=output_nc, num_inputchannels=input_nc) \n    else:\n        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)\n    return init_net(net, init_type, init_gain, gpu_ids)\n\n\ndef define_D(input_nc, ndf, netD,\n             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):\n    net = None\n    norm_layer = get_norm_layer(norm_type=norm)\n\n    if netD == 'basic':\n        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)\n    elif netD == 'n_layers':\n        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)\n    elif netD == 'pixel':\n        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)\n    elif netD == 'unet_32':\n        net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n    else:\n        raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)\n    return init_net(net, init_type, init_gain, 
gpu_ids)\n\n\n##############################################################################\n# Classes\n##############################################################################\n\n\n# Defines the GAN loss which uses either LSGAN or the regular GAN.\n# When LSGAN is used, it is basically same as MSELoss,\n# but it abstracts away the need to create the target label tensor\n# that has the same size as the input\nclass GANLoss(nn.Module):\n    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):\n        super(GANLoss, self).__init__()\n        self.register_buffer('real_label', torch.tensor(target_real_label))\n        self.register_buffer('fake_label', torch.tensor(target_fake_label))\n        if use_lsgan:\n            self.loss = nn.MSELoss()\n        else:\n            self.loss = nn.BCELoss()\n\n    def get_target_tensor(self, input, target_is_real):\n        if target_is_real:\n            target_tensor = self.real_label\n        else:\n            target_tensor = self.fake_label\n        return target_tensor.expand_as(input)\n\n    def __call__(self, input, target_is_real):\n        target_tensor = self.get_target_tensor(input, target_is_real)\n        return self.loss(input, target_tensor)\n\n\n# Defines the generator that consists of Resnet blocks between a few\n# downsampling/upsampling operations.\n# Code and idea originally from Justin Johnson's architecture.\n# https://github.com/jcjohnson/fast-neural-style/\nclass ResnetGenerator(nn.Module):\n    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):\n        assert(n_blocks >= 0)\n        super(ResnetGenerator, self).__init__()\n        self.input_nc = input_nc\n        self.output_nc = output_nc\n        self.ngf = ngf\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == 
nn.InstanceNorm2d\n\n        model = [nn.ReflectionPad2d(3),\n                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,\n                           bias=use_bias),\n                 norm_layer(ngf),\n                 nn.ReLU(True)]\n\n        n_downsampling = 2\n        for i in range(n_downsampling):\n            mult = 2**i\n            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,\n                                stride=2, padding=1, bias=use_bias),\n                      norm_layer(ngf * mult * 2),\n                      nn.ReLU(True)]\n\n        mult = 2**n_downsampling\n        for i in range(n_blocks):\n            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]\n\n        for i in range(n_downsampling):\n            mult = 2**(n_downsampling - i)\n            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),\n                                         kernel_size=3, stride=2,\n                                         padding=1, output_padding=1,\n                                         bias=use_bias),\n                      norm_layer(int(ngf * mult / 2)),\n                      nn.ReLU(True)]\n        model += [nn.ReflectionPad2d(3)]\n        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]\n        model += [nn.Tanh()]\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, input):\n        return self.model(input)\n\n\n# Define a resnet block\nclass ResnetBlock(nn.Module):\n    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):\n        super(ResnetBlock, self).__init__()\n        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)\n\n    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):\n        conv_block = []\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += 
[nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n        else:\n            raise NotImplementedError('padding [%s] is not implemented' % padding_type)\n\n        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),\n                       norm_layer(dim),\n                       nn.ReLU(True)]\n        if use_dropout:\n            conv_block += [nn.Dropout(0.5)]\n\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n        else:\n            raise NotImplementedError('padding [%s] is not implemented' % padding_type)\n        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),\n                       norm_layer(dim)]\n\n        return nn.Sequential(*conv_block)\n\n    def forward(self, x):\n        out = x + self.conv_block(x)\n        return out\n\n\n# Defines the Unet generator.\n# |num_downs|: number of downsamplings in UNet. 
For example,\n# if |num_downs| == 7, image of size 128x128 will become of size 1x1\n# at the bottleneck\nclass UnetGenerator(nn.Module):\n    def __init__(self, input_nc, output_nc, num_downs, ngf=64,\n                 norm_layer=nn.BatchNorm2d, use_dropout=False):\n        super(UnetGenerator, self).__init__()\n\n        # construct unet structure\n        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)\n        for i in range(num_downs - 5):\n            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)\n        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)\n\n        self.model = unet_block\n\n    def forward(self, input):\n        return self.model(input)\n\n\n# Defines the submodule with skip connection.\n# X -------------------identity---------------------- X\n#   |-- downsampling -- |submodule| -- upsampling --|\nclass UnetSkipConnectionBlock(nn.Module):\n    def __init__(self, outer_nc, inner_nc, input_nc=None,\n                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):\n        super(UnetSkipConnectionBlock, self).__init__()\n        self.outermost = outermost\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n        if input_nc is None:\n            input_nc = 
outer_nc\n        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,\n                             stride=2, padding=1, bias=use_bias)\n        downrelu = nn.LeakyReLU(0.2, True)\n        downnorm = norm_layer(inner_nc)\n        uprelu = nn.ReLU(True)\n        upnorm = norm_layer(outer_nc)\n\n        if outermost:\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1)\n            down = [downconv]\n            up = [uprelu, upconv] # , nn.Tanh()]\n            model = down + [submodule] + up\n        elif innermost:\n            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1, bias=use_bias)\n            down = [downrelu, downconv]\n            up = [uprelu, upconv, upnorm]\n            model = down + up\n        else:\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1, bias=use_bias)\n            down = [downrelu, downconv, downnorm]\n            up = [uprelu, upconv, upnorm]\n\n            if use_dropout:\n                model = down + [submodule] + up + [nn.Dropout(0.5)]\n            else:\n                model = down + [submodule] + up\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, x):\n        if self.outermost:\n            return self.model(x)\n        else:\n            return torch.cat([x, self.model(x)], 1)\n\n\n# Defines the PatchGAN discriminator with the specified arguments.\nclass NLayerDiscriminator(nn.Module):\n    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):\n        super(NLayerDiscriminator, self).__init__()\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func 
== nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n\n        kw = 4\n        padw = 1\n        sequence = [\n            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),\n            nn.LeakyReLU(0.2, True)\n        ]\n\n        nf_mult = 1\n        nf_mult_prev = 1\n        for n in range(1, n_layers):\n            nf_mult_prev = nf_mult\n            nf_mult = min(2**n, 8)\n            sequence += [\n                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n                norm_layer(ndf * nf_mult),\n                nn.LeakyReLU(0.2, True)\n            ]\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2**n_layers, 8)\n        sequence += [\n            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n            norm_layer(ndf * nf_mult),\n            nn.LeakyReLU(0.2, True)\n        ]\n\n        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]\n\n        if use_sigmoid:\n            sequence += [nn.Sigmoid()]\n\n        self.model = nn.Sequential(*sequence)\n\n    def forward(self, input):\n        return self.model(input)\n\n\nclass PixelDiscriminator(nn.Module):\n    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):\n        super(PixelDiscriminator, self).__init__()\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n\n        self.net = [\n            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),\n            nn.LeakyReLU(0.2, True),\n            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),\n            norm_layer(ndf * 2),\n            nn.LeakyReLU(0.2, True),\n            nn.Conv2d(ndf * 
2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]\n\n        if use_sigmoid:\n            self.net.append(nn.Sigmoid())\n\n        self.net = nn.Sequential(*self.net)\n\n    def forward(self, input):\n        return self.net(input)\n"
  },
  {
    "path": "models/resnet.py",
    "content": "import torch\nimport torch.nn as nn\n#from .utils import load_state_dict_from_url\ntry:\n    from torch.hub import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\n\n__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',\n           'wide_resnet50_2', 'wide_resnet101_2']\n\n\nmodel_urls = {\n    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n    
    if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None):\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = 
norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    def my_load_state_dict(self,state_dict):\n        own_state = self.state_dict()\n        for name,param in state_dict.items():\n            if name not in own_state:\n                continue\n            if isinstance(param, torch.Tensor):\n                param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except:\n                    if len(own_state[name].shape)>=2:\n                        torch.nn.init.xavier_normal_(own_state[name])\n                    continue\n\n    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None,\n                 norm_layer=None,num_inputchannels=3):\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution 
instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(num_inputchannels, self.inplanes, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n           
     if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x):\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\n    def forward(self, x):\n        return self._forward_impl(x)\n\n\ndef _resnet(arch, block, layers, pretrained, progress, **kwargs):\n    model = ResNet(block, layers, **kwargs)\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch],\n                                              progress=progress)\n        
model.my_load_state_dict(state_dict)\n    return model\n\n\ndef resnet18(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-18 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet34(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-34 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet50(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-50 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet101(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-101 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n                  
 **kwargs)\n\n\ndef resnet152(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-152 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnext50_32x4d(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNeXt-50 32x4d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef resnext101_32x8d(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNeXt-101 32x8d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 8\n    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet50_2(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Wide ResNet-50-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. 
The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet101_2(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Wide ResNet-101-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\nif __name__=='__main__':\n    a = resnext101_32x8d(pretrained=True,num_classes=6,num_inputchannels=4)\n    print(a)\n"
  },
  {
    "path": "models/vgg.py",
    "content": "from __future__ import print_function\nfrom __future__ import division\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport numpy as np\nimport torchvision\nfrom torchvision import datasets, models, transforms\nimport matplotlib.pyplot as plt\nimport time\nimport os\nimport copy\n\ndef set_parameter_requires_grad(model, feature_extracting):\n    if feature_extracting:\n        for param in model.parameters():\n            param.requires_grad = False\n\ndef create_vgg(num_ic,num_classes,use_pretrained=False,feature_extract=False):\n    model_ft = models.vgg16(pretrained=use_pretrained)\n    set_parameter_requires_grad(model_ft, feature_extract)\n    num_ftrs = model_ft.classifier[6].in_features\n    \n    model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)\n    modules =[]\n    for i in model_ft.features:\n        modules.append(i)\n    modules[0] = nn.Conv2d(num_ic, 64, kernel_size=3, padding=1) \n    model_ft.features=nn.Sequential(*modules)\n\n\n    modules2=[]\n    for i in model_ft.classifier:\n        modules2.append(i)\n    modules2.append(nn.Tanh())\n    model_ft.classifier = nn.Sequential(*modules2)\n    input_size = 224       \n    return model_ft\nif __name__=='__main__':\n    a = create_vgg(4,6)\n    print(a)\n    inp = torch.ones((1,4,128,128))\n    print(a(inp).shape)\n\n"
  },
  {
    "path": "options/PAMI_options.py",
    "content": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)\n        parser.add_argument('--display_freq', type=int, default=40, help='frequency of showing training results on screen')\n        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')\n        parser.add_argument('--display_server', type=str, default=\"http://bigeye.cs.stonybrook.edu\", help='visdom server of the web display')\n        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')\n        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')\n        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')\n        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')\n        parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')\n        
parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')\n        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')\n        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n        parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')\n        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n        parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')\n        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')\n\n        self.isTrain = True\n        return parser\n"
  },
  {
    "path": "options/__init__.py",
    "content": ""
  },
  {
    "path": "options/base_options.py",
    "content": "import argparse\nimport os\nfrom util import util\nimport torch\nimport models\nimport data\n\n\nclass BaseOptions():\n    def __init__(self):\n        self.initialized = False\n\n    def initialize(self, parser):\n        parser.add_argument('--dataroot',  help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')\n        parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')\n        parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')\n        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')\n        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')\n        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')\n        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')\n        parser.add_argument('--netD', type=str, default='basic', help='selects model to use for netD')\n        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='selects model to use for netG')\n        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')\n        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single]')\n        parser.add_argument('--model', type=str, default='cycle_gan',\n                            help='chooses which model to use. cycle_gan, pix2pix, test')\n        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')\n        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n        parser.add_argument('--randomSize', action='store_true', help='if specified, do not flip the images for data augmentation')\n        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')\n        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')\n        parser.add_argument('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n        parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]')\n        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n        parser.add_argument('--no_crop', action='store_true', help='if specified, do not flip the images for data augmentation')\n        parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')\n        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')\n        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}')\n        parser.add_argument('--lambda_GAN', type=float, default=0.0)\n        parser.add_argument('--lambda_smooth', type=float, default=0.0)\n        parser.add_argument('--lambda_L1', type=float, default=0.0)\n        parser.add_argument('--lambda_bd', type=float, default=0.0)\n        parser.add_argument('--keep_ratio', action='store_true')\n        parser.add_argument('--norm_mean', type=list, default=[0.5,0.5,0.5])\n        parser.add_argument('--norm_std', type=list, default=[0.5,0.5,0.5])\n        parser.add_argument('--finetuning', action='store_true')\n        parser.add_argument('--finetuning_name', type=str)\n        parser.add_argument('--finetuning_epoch', type=str)\n        parser.add_argument('--finetuning_dir', type=str)\n\n        self.initialized = True\n        return parser\n\n    def gather_options(self):\n        # initialize parser with basic options\n        if not 
self.initialized:\n            parser = argparse.ArgumentParser(\n                formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n            parser = self.initialize(parser)\n\n        # get the basic options\n        opt, _ = parser.parse_known_args()\n\n        # modify model-related parser options\n        model_name = opt.model\n        model_option_setter = models.get_option_setter(model_name)\n        parser = model_option_setter(parser, self.isTrain)\n        opt, _ = parser.parse_known_args()  # parse again with the new defaults\n\n        # modify dataset-related parser options\n        dataset_name = opt.dataset_mode\n        dataset_option_setter = data.get_option_setter(dataset_name)\n        parser = dataset_option_setter(parser, self.isTrain)\n\n        self.parser = parser\n\n        return parser.parse_args()\n\n    def print_options(self, opt):\n        message = ''\n        message += '----------------- Options ---------------\\n'\n        for k, v in sorted(vars(opt).items()):\n            comment = ''\n            default = self.parser.get_default(k)\n            if v != default:\n                comment = '\\t[default: %s]' % str(default)\n            message += '{:>25}: {:<30}{}\\n'.format(str(k), str(v), comment)\n        message += '----------------- End -------------------'\n        print(message)\n\n        # save to the disk\n        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)\n        util.mkdirs(expr_dir)\n        file_name = os.path.join(expr_dir, 'opt.txt')\n        with open(file_name, 'wt') as opt_file:\n            opt_file.write(message)\n            opt_file.write('\\n')\n\n    def parse(self):\n\n        opt = self.gather_options()\n        opt.isTrain = self.isTrain   # train or test\n\n        # process opt.suffix\n        if opt.suffix:\n            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''\n            opt.name = opt.name + suffix\n\n        self.print_options(opt)\n\n   
     # set gpu ids\n        str_ids = opt.gpu_ids.split(',')\n        opt.gpu_ids = []\n        for str_id in str_ids:\n            id = int(str_id)\n            if id >= 0:\n                opt.gpu_ids.append(id)\n        print(opt.gpu_ids)\n        if len(opt.gpu_ids) > 0:\n            torch.cuda.set_device(opt.gpu_ids[0])\n\n        self.opt = opt\n        return self.opt\n"
  },
  {
    "path": "options/test_options.py",
    "content": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)\n        parser.add_argument('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')\n        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')\n        #  Dropout and Batchnorm has different behavioir during training and test.\n        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')\n        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')\n\n        parser.set_defaults(model='test')\n        # To avoid cropping, the loadSize should be the same as fineSize\n        parser.set_defaults(loadSize=parser.get_default('fineSize'))\n        self.isTrain = False\n        return parser\n"
  },
  {
    "path": "options/train_options.py",
    "content": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)\n        parser.add_argument('--display_freq', type=int, default=40, help='frequency of showing training results on screen')\n        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')\n        parser.add_argument('--display_server', type=str, default=\"http://bigiris.cs.stonybrook.edu\", help='visdom server of the web display')\n        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n        parser.add_argument('--param_path', type=str)\n        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')\n        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')\n        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')\n        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')\n        parser.add_argument('--niter', type=int, default=100, 
help='# of iter at starting learning rate')\n        parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')\n        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')\n        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n        parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')\n        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n        parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')\n        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')\n\n        parser.add_argument('--n', type=int, default=3)\n        parser.add_argument('--ks', type=int, default=3)\n        parser.add_argument('--rks', type=int, default=3)\n        parser.add_argument('--shadow_loss', type=float, default=0.0)\n        parser.add_argument('--tv_loss', type=float, default=0.0)\n        parser.add_argument('--grad_loss', type=float, default=0.0)\n        parser.add_argument('--pgrad_loss', type=float, default=0.0)\n        parser.add_argument('--load_dir', type=str, default='')\n        parser.add_argument('--optimizer', type=str, default='adam')\n\n        self.isTrain = True\n        return parser\n"
  },
  {
    "path": "util/__init__.py",
    "content": ""
  },
  {
    "path": "util/get_data.py",
    "content": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile import ZipFile\nfrom bs4 import BeautifulSoup\nfrom os.path import abspath, isdir, join, basename\n\n\nclass GetData(object):\n    \"\"\"\n\n    Download CycleGAN or Pix2Pix Data.\n\n    Args:\n        technique : str\n            One of: 'cyclegan' or 'pix2pix'.\n        verbose : bool\n            If True, print additional information.\n\n    Examples:\n        >>> from util.get_data import GetData\n        >>> gd = GetData(technique='cyclegan')\n        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.\n\n    \"\"\"\n\n    def __init__(self, technique='cyclegan', verbose=True):\n        url_dict = {\n            'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',\n            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'\n        }\n        self.url = url_dict.get(technique.lower())\n        self._verbose = verbose\n\n    def _print(self, text):\n        if self._verbose:\n            print(text)\n\n    @staticmethod\n    def _get_options(r):\n        soup = BeautifulSoup(r.text, 'lxml')\n        options = [h.text for h in soup.find_all('a', href=True)\n                   if h.text.endswith(('.zip', 'tar.gz'))]\n        return options\n\n    def _present_options(self):\n        r = requests.get(self.url)\n        options = self._get_options(r)\n        print('Options:\\n')\n        for i, o in enumerate(options):\n            print(\"{0}: {1}\".format(i, o))\n        choice = input(\"\\nPlease enter the number of the \"\n                       \"dataset above you wish to download:\")\n        return options[int(choice)]\n\n    def _download_data(self, dataset_url, save_path):\n        if not isdir(save_path):\n            os.makedirs(save_path)\n\n        base = basename(dataset_url)\n        temp_save_path = join(save_path, 
base)\n\n        with open(temp_save_path, \"wb\") as f:\n            r = requests.get(dataset_url)\n            f.write(r.content)\n\n        if base.endswith('.tar.gz'):\n            obj = tarfile.open(temp_save_path)\n        elif base.endswith('.zip'):\n            obj = ZipFile(temp_save_path, 'r')\n        else:\n            raise ValueError(\"Unknown File Type: {0}.\".format(base))\n\n        self._print(\"Unpacking Data...\")\n        obj.extractall(save_path)\n        obj.close()\n        os.remove(temp_save_path)\n\n    def get(self, save_path, dataset=None):\n        \"\"\"\n\n        Download a dataset.\n\n        Args:\n            save_path : str\n                A directory to save the data to.\n            dataset : str, optional\n                A specific dataset to download.\n                Note: this must include the file extension.\n                If None, options will be presented for you\n                to choose from.\n\n        Returns:\n            save_path_full : str\n                The absolute path to the downloaded data.\n\n        \"\"\"\n        if dataset is None:\n            selected_dataset = self._present_options()\n        else:\n            selected_dataset = dataset\n\n        save_path_full = join(save_path, selected_dataset.split('.')[0])\n\n        if isdir(save_path_full):\n            warn(\"\\n'{0}' already exists. Voiding Download.\".format(\n                save_path_full))\n        else:\n            self._print('Downloading Data...')\n            url = \"{0}/{1}\".format(self.url, selected_dataset)\n            self._download_data(url, save_path=save_path)\n\n        return abspath(save_path_full)\n"
  },
  {
    "path": "util/html.py",
    "content": "import dominate\nfrom dominate.tags import *\nimport os\n\n\nclass HTML:\n    def __init__(self, web_dir, title, reflesh=0):\n        self.title = title\n        self.web_dir = web_dir\n        self.img_dir = os.path.join(self.web_dir, 'images')\n        if not os.path.exists(self.web_dir):\n            os.makedirs(self.web_dir)\n        if not os.path.exists(self.img_dir):\n            os.makedirs(self.img_dir)\n        # print(self.img_dir)\n\n        self.doc = dominate.document(title=title)\n        if reflesh > 0:\n            with self.doc.head:\n                meta(http_equiv=\"reflesh\", content=str(reflesh))\n\n    def get_image_dir(self):\n        return self.img_dir\n\n    def add_header(self, str):\n        with self.doc:\n            h3(str)\n\n    def add_table(self, border=1):\n        self.t = table(border=border, style=\"table-layout: fixed;\")\n        self.doc.add(self.t)\n\n    def add_images(self, ims, txts, links, width=400):\n        self.add_table()\n        with self.t:\n            with tr():\n                for im, txt, link in zip(ims, txts, links):\n                    with td(style=\"word-wrap: break-word;\", halign=\"center\", valign=\"top\"):\n                        with p():\n                            with a(href=os.path.join('images', link)):\n                                img(style=\"width:%dpx\" % width, src=os.path.join('images', im))\n                            br()\n                            p(txt)\n\n    def save(self):\n        html_file = '%s/index.html' % self.web_dir\n        f = open(html_file, 'wt')\n        f.write(self.doc.render())\n        f.close()\n\n\nif __name__ == '__main__':\n    html = HTML('web/', 'test_html')\n    html.add_header('hello world')\n\n    ims = []\n    txts = []\n    links = []\n    for n in range(4):\n        ims.append('image_%d.png' % n)\n        txts.append('text_%d' % n)\n        links.append('image_%d.png' % n)\n    html.add_images(ims, txts, links)\n    
html.save()\n"
  },
  {
    "path": "util/image_pool.py",
    "content": "import random\nimport torch\n\n\nclass ImagePool():\n    def __init__(self, pool_size):\n        self.pool_size = pool_size\n        if self.pool_size > 0:\n            self.num_imgs = 0\n            self.images = []\n\n    def query(self, images):\n        if self.pool_size == 0:\n            return images\n        return_images = []\n        for image in images:\n            image = torch.unsqueeze(image.data, 0)\n            if self.num_imgs < self.pool_size:\n                self.num_imgs = self.num_imgs + 1\n                self.images.append(image)\n                return_images.append(image)\n            else:\n                p = random.uniform(0, 1)\n                if p > 0.5:\n                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive\n                    tmp = self.images[random_id].clone()\n                    self.images[random_id] = image\n                    return_images.append(tmp)\n                else:\n                    return_images.append(image)\n        return_images = torch.cat(return_images, 0)\n        return return_images\n"
  },
  {
    "path": "util/util.py",
    "content": "from __future__ import print_function\nimport torch\nimport numpy as np\nfrom PIL import Image\nimport os\n\n\ndef sdmkdir(dir_name):\n    if not os.path.exists(dir_name):\n        os.mkdir(dir_name)\n# Converts a Tensor into an image array (numpy)\n# |imtype|: the desired type of the converted numpy array\ndef tensor2im(input_image, imtype=np.uint8,scale=None):\n    \n    if len(input_image.shape)<3: return None\n   # if scale>0 and input_image.size()[1]==3:\n   #     return tensor2im_logc(input_image, imtype=np.uint8,scale=scale)\n\n    if isinstance(input_image, torch.Tensor):\n        image_tensor = input_image.data\n    else:\n        return input_image\n    image_numpy = image_tensor.data[0].cpu().float().numpy()\n    if image_numpy.shape[0] == 1:\n        image_numpy = np.tile(image_numpy, (3, 1, 1))\n    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0\n    image_numpy[image_numpy<0] = 0\n    image_numpy[image_numpy>255] = 255\n    return image_numpy.astype(imtype)\n\ndef tensor2im_logc(image_tensor, imtype=np.uint8,scale=255):\n    image_numpy = image_tensor.data[0].cpu().double().numpy()\n    image_numpy = np.transpose(image_numpy,(1,2,0))\n    image_numpy = (image_numpy+1) /2.0  \n    image_numpy = image_numpy * (np.log(scale+1)) \n    image_numpy = np.exp(image_numpy) -1\n    image_numpy = image_numpy.astype(np.uint8)\n\n    return image_numpy.astype(np.uint8)\n\n\ndef diagnose_network(net, name='network'):\n    mean = 0.0\n    count = 0\n    for param in net.parameters():\n        if param.grad is not None:\n            mean += torch.mean(torch.abs(param.grad.data))\n            count += 1\n    if count > 0:\n        mean = mean / count\n    print(name)\n    print(mean)\n\n\ndef save_image(image_numpy, image_path):\n    image_pil = Image.fromarray(image_numpy)\n    image_pil.save(image_path)\n\n\ndef print_numpy(x, val=True, shp=False):\n    x = x.astype(np.float64)\n    if shp:\n        print('shape,', x.shape)\n  
  if val:\n        x = x.flatten()\n        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (\n            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))\n\n\ndef mkdirs(paths):\n    if isinstance(paths, list) and not isinstance(paths, str):\n        for path in paths:\n            mkdir(path)\n    else:\n        mkdir(paths)\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n"
  },
  {
    "path": "util/visualizer.py",
    "content": "import numpy as np\nimport os\nimport ntpath\nimport time\nfrom . import util\nfrom . import html\nfrom pdb import set_trace as st\nclass Visualizer():\n    def __init__(self, opt):\n        # self.opt = opt\n        self.display_id = opt.display_id\n        self.use_html = (opt.isTrain or opt.isTrainMatte) and not opt.no_html\n\n        self.win_size = opt.display_winsize\n        self.name = opt.name\n        if self.display_id > 0:\n            import visdom\n            self.vis = visdom.Visdom(server = opt.display_server,port = opt.display_port)\n            #self.ncols = opt.ncols\n            self.ncols = opt.display_ncols\n        if self.use_html:\n            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')\n            self.img_dir = os.path.join(self.web_dir, 'images')\n            print('create web directory %s...' % self.web_dir)\n            util.mkdirs([self.web_dir, self.img_dir])\n        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')\n        self.log_name_test = os.path.join(opt.checkpoints_dir, opt.name, 'test_log.txt')\n        with open(self.log_name, \"a\") as log_file:\n            now = time.strftime(\"%c\")\n            log_file.write('================ Training Loss (%s) ================\\n' % now)\n\n    # |visuals|: dictionary of images to display or save\n    def display_current_results(self, visuals, epoch):\n        if self.display_id > 0: # show images in the browser\n            ncols = self.ncols\n            if self.ncols > 0:\n                h, w = next(iter(visuals.values())).shape[:2]\n                table_css = \"\"\"<style>\n    table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}\n    table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}\n</style>\"\"\" % (w, h)\n                ncols = self.ncols\n                title = self.name\n                label_html = ''\n                label_html_row = 
''\n                nrows = int(np.ceil(len(visuals.items()) / ncols))\n                images = []\n                idx = 0\n                for label, image_numpy in visuals.items():\n                    label_html_row += '<td>%s</td>' % label\n                    images.append(image_numpy.transpose([2, 0, 1]))\n                    idx += 1\n                    if idx % ncols == 0:\n                        label_html += '<tr>%s</tr>' % label_html_row\n                        label_html_row = ''\n                '''\n                white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255\n                while idx % ncols != 0:\n                    images.append(white_image)\n                    label_html_row += '<td></td>'\n                    idx += 1\n                '''\n                if label_html_row != '':\n                    label_html += '<tr>%s</tr>' % label_html_row\n                # pane col = image row\n                self.vis.images(images, nrow=ncols, win=self.display_id + 1,\n                                padding=2, opts=dict(title=title + ' images'))\n                #label_html = '<table>%s</table>' % label_html\n                #self.vis.text(table_css + label_html, win = self.display_id + 2,\n                #              opts=dict(title=title + ' labels'))\n            else:\n                idx = 1\n                for label, image_numpy in visuals.items():\n                    self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),\n                                       win=self.display_id + idx)\n                    idx += 1\n\n        if self.use_html: # save images to a html file\n            for label, image_numpy in visuals.items():\n                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))\n                util.save_image(image_numpy, img_path)\n            # update website\n            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)\n   
         for n in range(epoch, 0, -1):\n                webpage.add_header('epoch [%d]' % n)\n                ims = []\n                txts = []\n                links = []\n\n                for label, image_numpy in visuals.items():\n                    img_path = 'epoch%.3d_%s.png' % (n, label)\n                    ims.append(img_path)\n                    txts.append(label)\n                    links.append(img_path)\n                webpage.add_images(ims, txts, links, width=self.win_size)\n            webpage.save()\n\n    # errors: dictionary of error labels and values\n    def plot_current_losses(self, epoch, counter_ratio, opt, errors):\n        if not hasattr(self, 'plot_data_train'):\n            self.plot_data_train = {'X':[],'Y':[], 'legend':list(errors.keys())}\n        self.plot_data_train['X'].append(epoch + counter_ratio)\n        self.plot_data_train['Y'].append([errors[k] for k in self.plot_data_train['legend']])\n        self.vis.line(\n            X=np.stack([np.array(self.plot_data_train['X'])]*len(self.plot_data_train['legend']),1),\n            Y=np.array(self.plot_data_train['Y']),\n            opts={\n                'title': self.name + ' loss over time',\n                'legend': self.plot_data_train['legend'],\n                'xlabel': 'epoch',\n                'ylabel': 'loss'},\n            win=self.display_id)\n\n    def plot_test_errors(self, epoch, counter_ratio, opt, errors):\n        if not hasattr(self, 'plot_data'):\n            self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}\n        self.plot_data['X'].append(epoch + counter_ratio)\n        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])\n        self.vis.line(\n            X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),\n            Y=np.array(self.plot_data['Y']),\n            opts={\n                'title': self.name + ' loss over time',\n                'legend': self.plot_data['legend'],\n          
      'xlabel': 'epoch',\n                'ylabel': 'loss'},\n            win=self.display_id+10)\n        message = '(epoch: %d)' %(epoch)\n        for k, v in errors.items():\n            message += '%s: %.3f ' % (k, v)\n\n        print(message)\n        with open(self.log_name_test, \"a\") as log_file:\n            log_file.write('%s\\n' % message)\n    \n    # errors: same format as |errors| of plotCurrentErrors\n    def print_current_errors(self, epoch, i, errors, t):\n        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)\n        for k, v in errors.items():\n            message += '%s: %.3f ' % (k, v)\n\n        print(message)\n        with open(self.log_name, \"a\") as log_file:\n            log_file.write('%s\\n' % message)\n\n    # save image to the disk\n    def save_images(self, webpage, visuals, image_path):\n        image_dir = webpage.get_image_dir()\n        short_path = ntpath.basename(image_path[0])\n        name = os.path.splitext(short_path)[0]\n\n        webpage.add_header(name)\n        ims = []\n        txts = []\n        links = []\n\n        for label, image_numpy in visuals.items():\n            image_name = '%s_%s.png' % (name, label)\n            save_path = os.path.join(image_dir, image_name)\n            util.save_image(image_numpy, save_path)\n\n            ims.append(image_name)\n            txts.append(label)\n            links.append(image_name)\n        webpage.add_images(ims, txts, links, width=self.win_size)\n"
  }
]