[
  {
    "path": ".gitignore",
    "content": "*.*~\n*~\n__pycache__\nscripts/logs_*"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 Boyuan Chen\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Discovering State Variables Hidden in Experimental Data\n\n[Boyuan Chen](http://boyuanchen.com/),\n[Kuang Huang](https://cm3.apam.columbia.edu/people/kuang-huang),\n[Sunand Raghupathi](https://www.linkedin.com/in/sunand-raghupathi?trk=public_profile_browsemap),\n[Ishaan Chandratreya](https://www.linkedin.com/in/ishaan-chandratreya-546aa6141),\n[Qiang Du](https://www.engineering.columbia.edu/faculty/qiang-du),\n[Hod Lipson](https://www.hodlipson.com/)\n<br>\nColumbia University\n<br>\n\n### [Project Website](https://www.cs.columbia.edu/~bchen/neural-state-variables) | [Video](https://youtu.be/KWHXchlJzSw) | [Paper](http://arxiv.org/abs/2112.10755)\n\n## Overview\nThis repo contains the PyTorch implementation of the paper \"Discovering State Variables Hidden in Experimental Data\".\n\n![teaser](figures/teaser.gif)\n\n## Citation\n\nIf you find our paper or codebase helpful, please consider citing:\n\n```\n@article{chen2021discover,\n  title={Discovering State Variables Hidden in Experimental Data},\n  author={Chen, Boyuan and Huang, Kuang and Raghupathi, Sunand and Chandratreya, Ishaan and Du, Qiang and Lipson, Hod},\n  journal={arXiv preprint arXiv:2112.10755},\n  year={2021}\n}\n```\n\n## Contents\n\n- [Installation](#installation)\n- [Logging](#logging)\n- [Data Preparation](#data-preparation)\n- [Training and Testing](#training-and-testing)\n- [Intrinsic Dimension Estimation](#intrinsic-dimension-estimation)\n- [Long-term Prediction and Stability Evaluation](#long-term-prediction-and-stability-evaluation)\n- [Evaluation and Analysis](#evaluation-and-analysis)\n- [License](#license)\n\n## Installation\n\nThe installation has been tested on Ubuntu 18.04 with CUDA 11.0. 
All experiments were run on a single NVIDIA GeForce RTX 2080 Ti GPU.\n\nCreate a Python virtual environment and install the dependencies.\n```\nvirtualenv -p /usr/bin/python3.6 env3.6\nsource env3.6/bin/activate\npip install -r requirements.txt\n```\n**Note**: You may need MATLAB installed on your computer to use some of the data collectors and intrinsic dimension estimation algorithms.\n\n## Logging\n\nWe first describe the naming convention for saved files, so that it is clear what is saved and where.\n\n1. Log folder naming convention:\n    ```\n    logs_{dataset}_{model_name}_{seed}\n    ```\n2. Inside the logs folder, the structure and contents are:\n    ```\n    \\logs_{dataset}_{model_name}_{seed}\n        \\lightning_logs\n            \\checkpoints               [saved checkpoints]\n            \\version_0                 [training stats]\n            \\version_1                 [testing stats]\n        \\predictions                   [testing predicted images]\n        \\prediction_long_term          [long-term predicted images]\n        \\variables                     [file id and latent vectors on testing data]\n        \\variables_train               [file id and latent vectors on training data]\n        \\variables_val                 [file id and latent vectors on validation data]\n    ```\n\n## Data Preparation\n\nWe provide nine datasets; their download links are listed below.\n\n- [circular_motion](https://drive.google.com/file/d/19DFYqh08B2L-YX_bTmpmxYu2jpARqoUA/view?usp=sharing) (circular motion system)\n- [reaction_diffusion](https://drive.google.com/file/d/1LOqpB0l86lELEegX5sV9NwHkGKeHoaTN/view?usp=sharing) (reaction diffusion system)\n- [single_pendulum](https://drive.google.com/file/d/1rw6vrVcV3KCIaVVwKMMhoZ9Jph5u-Zdf/view?usp=sharing) (single pendulum system)\n- [double_pendulum](https://drive.google.com/file/d/1QEtk4JjnRysIEjtkZIKBjACu_IiTAdX6/view?usp=sharing) (rigid double pendulum system)\n- 
[elastic_pendulum](https://drive.google.com/file/d/1Y8vzawQZhzPHp6cpLFuyM2LHuhgLjH9H/view?usp=sharing) (elastic double pendulum system)\n- [swingstick_non_magnetic](https://drive.google.com/file/d/1BfeGW4XTFyGdyBO0G_YnSnRyJGu2WRnc/view?usp=sharing) (swing stick system)\n- [air_dancer](https://drive.google.com/file/d/163KvwevY1fnDnI6WiWpc5Zaz-WWwPaAS/view?usp=sharing) (air dancer system)\n- [lava_lamp](https://drive.google.com/file/d/1R-I2CZaJLe2D4H818-n25l54R0pbKQfo/view?usp=sharing) (lava lamp system)\n- [fire](https://drive.google.com/file/d/1OIH6SalwPyD_lkXBLZq6bzmKfrp-xhrI/view?usp=sharing) (fire system)\n\nSave each downloaded dataset as ```data/{dataset_name}```, where ```data``` is your dataset folder. **Make sure that ```data``` is an absolute path, and set the ```data_filepath``` item in the ```config.yaml``` files under ```configs``` to point to this folder**.\n\nPlease refer to the [datainfo](datainfo) folder for more details about the data structure and the dataset collection process.\n\n\n## Training and Testing\n\nOur approach involves three models:\n- dynamics predictive model (encoder-decoder / encoder-decoder-64)\n- latent reconstruction model (refine-64)\n- neural latent dynamics model (latentpred)\n\n1. Navigate to the scripts folder\n    ```\n    cd scripts\n    ```\n\n2. Train the dynamics predictive model (encoder-decoder and encoder-decoder-64) and then save the high-dimensional latent vectors from the testing data.\n    ```\n    ./encoder_decoder_64_train.sh {dataset_name} {gpu no.}\n    ./encoder_decoder_train.sh {dataset_name} {gpu no.}\n    ./encoder_decoder_64_eval.sh {dataset_name} {gpu no.}\n    ./encoder_decoder_eval.sh {dataset_name} {gpu no.}\n    ```\n\n3. Run a forward pass of the trained encoder-decoder and encoder-decoder-64 models to save the high-dimensional latent vectors from the training and validation data. 
These latent vectors serve as the training and validation data for the latent reconstruction model.\n    ```\n    ./encoder_decoder_eval_gather.sh {dataset_name} {gpu no.}\n    ./encoder_decoder_64_eval_gather.sh {dataset_name} {gpu no.}\n    ```\n\n4. Before proceeding with this step, follow the next section to obtain the system's intrinsic dimension, then return here.\n\n    Train the latent reconstruction model (refine-64) on the saved 64-dim latent vectors from the previous steps. Then save the resulting Neural State Variables from both the training and testing data.\n    ```\n    ./refine_64_train.sh {dataset_name} {gpu no.}\n    ./refine_64_eval.sh {dataset_name} {gpu no.}\n    ./refine_64_eval_gather.sh {dataset_name} {gpu no.}\n    ```\n\n5. Train the neural latent dynamics model (latentpred) using the models trained in the previous steps.\n    ```\n    ./latentpred_train.sh single_pendulum {dataset_name} {gpu no.}\n    ```\n\n## Intrinsic Dimension Estimation\n\nWith the trained dynamics predictive model, our approach provides subroutines to estimate the system's intrinsic dimension (ID) using manifold learning algorithms. The estimated intrinsic dimension determines the number of Neural State Variables and the design of the latent reconstruction model. You can train the latent reconstruction model to obtain Neural State Variables only after completing this step.\n\n1. Navigate to the scripts folder\n    ```\n    cd scripts\n    ```\n    which is the default directory where all models' log folders are saved.\n\n2. Estimate the intrinsic dimension from the saved latent vectors of the encoder-decoder models for all random seeds.\n    ```\n    ./encoder_decoder_estimate_dimension.sh {dataset_name}\n    ```\n\n3. 
Compute the final intrinsic dimension estimate (mean and standard deviation).\n    ```\n    python ../utils/dimension.py {dataset_name}\n    ```\n\n\n## Long-term Prediction and Stability Evaluation\n\nWith all the models trained above, our approach produces long-term predictions of the system through model rollouts and evaluates the stability of these predictions.\n\n1. Navigate to the scripts folder\n    ```\n    cd scripts\n    ```\n\n2. Long-term prediction with single model rollouts.\n    ```\n    ./encoder_decoder_long_term_model_rollout.sh {dataset_name} {gpu no.}\n    ./encoder_decoder_64_long_term_model_rollout.sh {dataset_name} {gpu no.}\n    ./refine_64_long_term_model_rollout.sh {dataset_name} {gpu no.}\n    ```\n    The predictions will be saved in the ```prediction_long_term``` subfolder under the model's log folder.\n\n3. Long-term prediction with hybrid model rollouts.\n    ```\n    ./refine_64_long_term_hybrid_rollout.sh {dataset_name} {gpu no.} {step}\n    ```\n    where ```step``` is the number of model rollouts via 64-dim latent vectors before a model rollout via Neural State Variables.\n\n    The predictions will be saved in the ```prediction_long_term``` subfolder under the refine-64 model's log folder.\n\n4. Long-term prediction with single model rollouts from perturbed initial frames.\n    ```\n    ./encoder_decoder_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.}\n    ./encoder_decoder_64_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.}\n    ./refine_64_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.}\n    ```\n\n    The predictions will be saved in the ```prediction_long_term``` subfolder under the model's log folder.\n\n5. 
Stability evaluation of long-term predictions from single model rollouts\n    ```\n    ./long_term_eval_stability.sh {dataset_name} {gpu no.}\n    ```\n    and of long-term predictions from hybrid model rollouts\n    ```\n    ./long_term_eval_stability_hybrid.sh {dataset_name} {gpu no.} {step}\n    ```\n    where ```step``` has the same meaning as for hybrid model rollouts above.\n\n    The latent space errors measuring prediction stability will be saved in the ```stability.npy``` file under the respective ```prediction_long_term``` folders.\n\n**Note**: The default long-term prediction length is 60 frames. To use a different prediction length, you need to modify the scripts.\n\n\n## Evaluation and Analysis\n\nPlease refer to the [analysis](analysis) folder for detailed instructions on physical evaluation and analysis.\n\n## License\n\nThis repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.\n"
  },
  {
    "path": "analysis/README.md",
    "content": "\n## Physics Evaluation on Known Systems\n\nWe provide physics evaluation algorithms that extract physical variables (positions, velocities, and energies) from ground-truth and predicted frames for the three known systems (single pendulum, rigid double pendulum, and elastic double pendulum).\nYou can use these algorithms to evaluate the physical accuracy of predictions.\n\n1. Evaluate physical variables from ground-truth data.\n    ```\n    python eval_phys_data.py dataset_name data_filepath\n    ```\n    The ```dataset_name``` can be single_pendulum, double_pendulum, or elastic_pendulum. The ```data_filepath``` is the respective data filepath, e.g., ```data/single_pendulum```, where ```data``` is your dataset folder. The results will be saved in the ```phys_vars.npy``` file under ```data_filepath```.\n\n2. Evaluate physical variables from long-term predictions via model rollouts and compute the physical errors between the ground-truth data and the predictions.\n    ```\n    python eval_phys_long_term_pred.py config_filepath pred_save_path\n    ```\n    The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The ```pred_save_path``` is the path to the predicted images, e.g., ```../scripts/logs_single_pendulum_encoder-decoder_1/prediction_long_term/model_rollout/```. The results will be saved in the ```phys_vars.npy```, ```phys_error.npy```, and ```pixel_error.npy``` files under ```pred_save_path```.\n\n**Note**: ```eval_phys_single_pendulum```, ```eval_phys_double_pendulum```, and ```eval_phys_elastic_pendulum``` are helper packages used by the commands above to evaluate physical variables.\n\n\n## Intrinsic Dimension Estimation\n\n1. Navigate to the scripts folder\n    ```\n    cd ../scripts\n    ```\n    which is the default directory where all models' log folders are saved.\n\n2. 
Estimate intrinsic dimension from model latent vectors using the Levina-Bickel method\n    ```\n    python ../analysis/eval_intrinsic_dimension.py {config_filepath} model-latent NA\n    ```\n    or using all methods (Levina-Bickel, MiND-ML, MiND-KL, Hein, CD)\n    ```\n    python ../analysis/eval_intrinsic_dimension.py {config_filepath} model-latent all-methods\n    ```\n    The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The results will be saved in the ```intrinsic_dimension.npy``` file (using the Levina-Bickel method) or the ```intrinsic_dimension_all_methods.npy``` file (using all methods) in the ```variables``` subfolder under the model's log folder.\n\n3. Estimate intrinsic dimension from raw data (image pairs) using the Levina-Bickel method\n    ```\n    python ../analysis/eval_intrinsic_dimension.py {config_filepath} data-image NA\n    ```\n    or using all methods (Levina-Bickel, MiND-ML, MiND-KL, Hein, CD)\n    ```\n    python ../analysis/eval_intrinsic_dimension.py {config_filepath} data-image all-methods\n    ```\n    The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The results will be saved in the ```intrinsic_dimension_image.npy``` file (using the Levina-Bickel method) or the ```intrinsic_dimension_image_all_methods.npy``` file (using all methods) in the ```variables``` subfolder under the model's log folder. In this mode, the model configuration only supplies the data filepath and the test video IDs.\n\n**Note**: ```intrinsic_dimension_estimation``` is a helper package providing various intrinsic dimension estimation methods. 
The ```all-methods``` option requires MATLAB and the MATLAB Engine API for Python to be installed on the machine; see the installation instructions [here](https://www.mathworks.com/help/matlab/matlab_external/install-the-matlab-engine-for-python.html).\n\nReferences for the MATLAB code:\n```\n[1] Gabriele Lombardi (2021). Intrinsic dimensionality estimation techniques (https://www.mathworks.com/matlabcentral/fileexchange/40112-intrinsic-dimensionality-estimation-techniques), MATLAB Central File Exchange. Retrieved October 21, 2021.\n[2] M. Hein and J.-Y. Audibert. Intrinsic dimensionality estimation of submanifolds in Euclidean space, Proceedings of the 22nd International Conference on Machine Learning (https://www.ml.uni-saarland.de/code/IntDim/IntDim.htm).\n```\n\n\n## Latent Space Regression on Known Systems\n\n1. Navigate to the scripts folder\n    ```\n    cd ../scripts\n    ```\n    which is the default directory where all models' log folders are saved.\n\n2. Regress physical variables from the model latent vectors\n    ```\n    python ../analysis/eval_regression.py config_filepath NA\n    ```\n    or from the first ```num_components``` principal components of the model latent vectors\n    ```\n    python ../analysis/eval_regression.py config_filepath num_components\n    ```\n    The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The results will be saved in the ```regression_results.npy``` file (with the model latent vectors) or the ```regression_results_pca_{num_components}.npy``` file (with the first ```num_components``` principal components of the model latent vectors) in the ```variables``` subfolder under the model's log folder."
  },
  {
    "path": "analysis/eval_intrinsic_dimension.py",
    "content": "import numpy as np\nimport os\nimport sys\nfrom tqdm import tqdm\nfrom PIL import Image\nimport json\nimport yaml\nimport pprint\nfrom munch import munchify\nfrom intrinsic_dimension_estimation import ID_Estimator\n\n\ndef load_config(filepath):\n    with open(filepath, 'r') as stream:\n        try:\n            trainer_params = yaml.safe_load(stream)\n            return trainer_params\n        except yaml.YAMLError as exc:\n            print(exc)\n\ndef remove_duplicates(X):\n    return np.unique(X, axis=0)\n\n\ndef eval_id_latent(vars_filepath, if_refine, if_all_methods):\n    if if_refine:\n        latent = np.load(os.path.join(vars_filepath, 'refine_latent.npy'))\n    else:\n        latent = np.load(os.path.join(vars_filepath, 'latent.npy'))\n    print(f'Number of samples: {latent.shape[0]}; Latent dimension: {latent.shape[1]}')\n    latent = remove_duplicates(latent)\n    print(f'Number of samples (duplicates removed): {latent.shape[0]}')\n    \n    estimator = ID_Estimator()\n    k_list = (latent.shape[0] * np.linspace(0.008, 0.016, 5)).astype('int')\n    print(f'List of numbers of nearest neighbors: {k_list}')\n    if if_all_methods:\n        dims = estimator.fit_all_methods(latent, k_list)\n        np.save(os.path.join(vars_filepath, 'intrinsic_dimension_all_methods.npy'), dims)\n    else:\n        dims = estimator.fit(latent, k_list)\n        np.save(os.path.join(vars_filepath, 'intrinsic_dimension.npy'), dims)\n\n\ndef eval_id_image(data_filepath, test_vid_ids, vars_filepath, if_all_methods):\n    print('Reading image data...')\n    data = []\n    for p_vid_idx in tqdm(test_vid_ids):\n        vid_filepath = os.path.join(data_filepath, str(p_vid_idx))\n        num_frames = len(os.listdir(vid_filepath))\n        suf = os.listdir(vid_filepath)[0].split('.')[-1]\n        for p_frame in range(num_frames - 3):\n            img_list = []\n            for p in range(2):\n                img = Image.open(os.path.join(vid_filepath, str(p_frame + p) 
+ '.' + suf))\n                img = img.resize((128, 128))\n                img = np.array(img) / 255.0\n                img_list.append(img)\n            data_p = np.concatenate(img_list, 1).reshape([1, -1])\n            data.append(data_p)\n    data = np.concatenate(data, 0)\n    print(f'Number of samples: {data.shape[0]}; Image dimension (flattened): {data.shape[1]}')\n    data = remove_duplicates(data)\n    print(f'Number of samples (duplicates removed): {data.shape[0]}')\n\n    estimator = ID_Estimator()\n    k_list = (data.shape[0] * np.linspace(0.008, 0.016, 5)).astype('int')\n    print(f'List of numbers of nearest neighbors: {k_list}')\n    if if_all_methods:\n        dims = estimator.fit_all_methods(data, k_list)\n        np.save(os.path.join(vars_filepath, 'intrinsic_dimension_image_all_methods.npy'), dims)\n    else:\n        dims = estimator.fit(data, k_list)\n        np.save(os.path.join(vars_filepath, 'intrinsic_dimension_image.npy'), dims)\n\n\nif __name__ == '__main__':\n    \n    config_filepath = str(sys.argv[1])\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    if_all_methods = (str(sys.argv[3]) == 'all-methods')\n    \n    if str(sys.argv[2]) == 'model-latent':\n        log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])\n        vars_filepath = os.path.join(log_dir, 'variables')\n        if_refine = ('refine' in cfg.model_name)\n        dims = eval_id_latent(vars_filepath, if_refine, if_all_methods)\n    \n    elif str(sys.argv[2]) == 'data-image':\n        data_filepath = os.path.join(cfg.data_filepath, cfg.dataset)\n        with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:\n            seq_dict = json.load(file)\n        test_vid_ids = seq_dict['test']\n        log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])\n        vars_filepath = os.path.join(log_dir, 'variables')\n        dims = 
eval_id_image(data_filepath, test_vid_ids, vars_filepath, if_all_methods)\n    \n    else:\n        assert False, 'Invalid arguments...'\n"
  },
  {
    "path": "analysis/eval_phys_data.py",
    "content": "import os\nimport sys\nimport cv2\nimport numpy as np\nfrom tqdm import tqdm\n\n\ndef eval_phys_data_single_pendulum(data_filepath, num_vids, num_frms, save_path):\n    from eval_phys_single_pendulum import eval_physics, phys_vars_list\n    phys = {p_var:[] for p_var in phys_vars_list}\n\n    for n in tqdm(range(num_vids)):\n        seq_filepath = os.path.join(data_filepath, str(n))\n        frames = []\n        for p in range(num_frms):\n            frame_p = cv2.imread(os.path.join(seq_filepath, str(p)+'.png'))\n            frames.append(frame_p)\n        phys_tmp = eval_physics(frames)\n        for p_var in phys_vars_list:\n            phys[p_var].append(phys_tmp[p_var])\n    for p_var in phys_vars_list:\n        phys[p_var] = np.array(phys[p_var])\n\n    np.save(save_path, phys)\n\n\ndef eval_phys_data_double_pendulum(data_filepath, num_vids, num_frms, save_path):\n    from eval_phys_double_pendulum import eval_physics, phys_vars_list\n    phys = {p_var:[] for p_var in phys_vars_list}\n\n    for n in tqdm(range(num_vids)):\n        seq_filepath = os.path.join(data_filepath, str(n))\n        frames = []\n        for p in range(num_frms):\n            frame_p = cv2.imread(os.path.join(seq_filepath, str(p)+'.png'))\n            frames.append(frame_p)\n        phys_tmp = eval_physics(frames)\n        for p_var in phys_vars_list:\n            phys[p_var].append(phys_tmp[p_var])\n    for p_var in phys_vars_list:\n        phys[p_var] = np.array(phys[p_var])\n\n    # remove outliers\n    thresh_1 = np.nanpercentile(np.abs(phys['vel_theta_1']), 98)\n    thresh_2 = np.nanpercentile(np.abs(phys['vel_theta_2']), 98)\n    for n in range(num_vids):\n        for p in range(num_frms):\n            if (not np.isnan(phys['vel_theta_1'][n, p]) and np.abs(phys['vel_theta_1'][n, p]) >= thresh_1) \\\n            or (not np.isnan(phys['vel_theta_2'][n, p]) and np.abs(phys['vel_theta_2'][n, p]) >= thresh_2):\n                phys['vel_theta_1'][n, p] = np.nan\n         
       phys['vel_theta_2'][n, p] = np.nan\n                phys['kinetic energy'][n, p] = np.nan\n                phys['total energy'][n, p] = np.nan\n\n    np.save(save_path, phys)\n\n\ndef eval_phys_data_elastic_pendulum(data_filepath, num_vids, num_frms, save_path):\n    from eval_phys_elastic_pendulum import eval_physics, phys_vars_list\n    phys = {p_var:[] for p_var in phys_vars_list}\n\n    for n in tqdm(range(num_vids)):\n        seq_filepath = os.path.join(data_filepath, str(n))\n        frames = []\n        for p in range(num_frms):\n            frame_p = cv2.imread(os.path.join(seq_filepath, str(p)+'.png'))\n            frames.append(frame_p)\n        phys_tmp = eval_physics(frames)\n        for p_var in phys_vars_list:\n            phys[p_var].append(phys_tmp[p_var])\n    for p_var in phys_vars_list:\n        phys[p_var] = np.array(phys[p_var])\n\n    # remove outliers\n    thresh_1 = np.nanpercentile(np.abs(phys['vel_theta_1']), 98)\n    thresh_2 = np.nanpercentile(np.abs(phys['vel_theta_2']), 98)\n    thresh_z = np.nanpercentile(np.abs(phys['vel_z']), 98)\n    for n in range(num_vids):\n        for p in range(num_frms):\n            if (not np.isnan(phys['vel_theta_1'][n, p]) and np.abs(phys['vel_theta_1'][n, p]) >= thresh_1) \\\n            or (not np.isnan(phys['vel_theta_2'][n, p]) and np.abs(phys['vel_theta_2'][n, p]) >= thresh_2) \\\n            or (not np.isnan(phys['vel_z'][n, p]) and np.abs(phys['vel_z'][n, p]) >= thresh_z):\n                phys['vel_theta_1'][n, p] = np.nan\n                phys['vel_theta_2'][n, p] = np.nan\n                phys['vel_z'][n, p] = np.nan\n                phys['kinetic energy'][n, p] = np.nan\n                phys['total energy'][n, p] = np.nan\n\n    np.save(save_path, phys)\n\n\nif __name__ == '__main__':\n    dataset = str(sys.argv[1])\n    data_filepath = str(sys.argv[2])\n    save_path = os.path.join(data_filepath, 'phys_vars.npy')\n    \n    if dataset == 'single_pendulum':\n        
eval_phys_data_single_pendulum(data_filepath, 1200, 60, save_path)\n    elif dataset == 'double_pendulum':\n        eval_phys_data_double_pendulum(data_filepath, 1100, 60, save_path)\n    elif dataset == 'elastic_pendulum':\n        eval_phys_data_elastic_pendulum(data_filepath, 1200, 60, save_path)\n    else:\n        assert False, 'Unknown system...'"
  },
  {
    "path": "analysis/eval_phys_double_pendulum/__init__.py",
    "content": "from .angle_estimator import obtain_angle\nfrom .physics_estimator import *\n\nimport numpy as np\n\n\nphys_vars_list = {'reject', 'theta_1', 'vel_theta_1', 'theta_2', 'vel_theta_2', 'kinetic energy',\n'potential energy', 'total energy'}\n\n\ndef eval_physics(frames):\n    num_frames = len(frames)\n    reject = np.zeros(num_frames, dtype=bool)\n    theta_1 = np.zeros(num_frames)\n    theta_2 = np.zeros(num_frames)\n    # estimate angles\n    for p in range(num_frames):\n        reject_p, angles_p, _ = obtain_angle(frames[p])\n        if reject_p:\n            reject[p] = True\n            theta_1[p] = np.nan\n            theta_2[p] = np.nan\n        else:\n            reject[p] = False\n            theta_1[p] = angles_p[0]\n            theta_2[p] = angles_p[1]\n    # calculate velocities\n    vel_theta_1 = np.zeros(num_frames)\n    vel_theta_2 = np.zeros(num_frames)\n    sub_ids = np.ma.clump_unmasked(np.ma.masked_array(theta_1, reject))\n    for ids in sub_ids:\n        vel_theta_1[ids] = calc_velocity(theta_1[ids].copy())\n        vel_theta_2[ids] = calc_velocity(theta_2[ids].copy())\n    vel_theta_1[reject] = np.nan\n    vel_theta_2[reject] = np.nan\n    # calculate energies\n    kinetic_energy, potential_energy, total_energy = calc_energy(theta_1, theta_2, vel_theta_1, vel_theta_2)\n    # save results\n    phys = dict.fromkeys(phys_vars_list)\n    phys['reject'] = reject\n    phys['theta_1'] = theta_1\n    phys['vel_theta_1'] = vel_theta_1\n    phys['theta_2'] = theta_2\n    phys['vel_theta_2'] = vel_theta_2\n    phys['kinetic energy'] = kinetic_energy\n    phys['potential energy'] = potential_energy\n    phys['total energy'] = total_energy\n    return phys"
  },
  {
    "path": "analysis/eval_phys_double_pendulum/angle_estimator.py",
    "content": "'''\nThis script provides utility functions estimating angles of a double pendulum from image.\nStep 1. Extract each of the two pendulums from the image.\nStep 2. Do a rectangle fitting of each pendulum.\nStep 3. Estimate the angles of the pendulums. \nCertain images will be rejected if any pendulum does not exist, is hidden, or has a wrong shape.\n'''\n\nimport cv2\nimport os\nimport numpy as np\n\n\n'''\nExtract the two pendulums from the image.\nArgs:\n    img: double pendulum image in BGR format\nReturns:\n    seg_1: segmentation of the first pendulum\n    seg_2: segmentation of the second pendulum\n'''\ndef seg_from_img(img):\n    # pixel thresholds (in HSV)\n    v_min_1 = (0, 0, 0)\n    v_max_1 = (255, 255, 140)\n    v_min_2 = (60, 0, 0)\n    v_max_2 = (255, 255, 255)\n    # background color in HSV (RGB: [215, 205, 192])\n    bg_color = [17, 27, 215]\n\n    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)\n    # extract the first pendulum\n    seg_1 = cv2.inRange(img_hsv, v_min_1, v_max_1)\n    # remove the first pendulum by replacing it with background pixels\n    img_hsv[seg_1==255] = bg_color\n    # extract the second pendulum\n    seg_2 = cv2.inRange(img_hsv, v_min_2, v_max_2)\n\n    return seg_1, seg_2\n\n\n'''\nFit pendulum in a rectangle.\nArgs:\n    seg: segmentation of the pendulum\nReturns:\n    rej: (True/False) if the image is rejected\n    rect: (Box2D structure) the fitted rectangle\n'''\ndef fit_pendulum(seg):\n    # find all contours\n    contours, _ = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n    # reject if no contours found\n    if len(contours) == 0:\n        return True, None\n    # find the contour with the maximum area\n    cnt = max(contours, key=cv2.contourArea)\n    area = cv2.contourArea(cnt)\n    # reject if the contour is too small\n    if area < 125:\n        return True, None\n    # rectangle fitting\n    rect = cv2.minAreaRect(cnt)\n    _, (width, height), _ = rect\n    # reject if the 
rectangle is too close to a square\n    if abs(height-width) < 8:\n        return True, None\n    # reject if the rectangle does not properly fit the contour\n    if height*width > 2 * area:\n        return True, None\n\n    return False, rect\n\n\n'''\nEstimate the angle of pendulum from its fitted rectangle.\nArgs:\n    rect: (Box2D structure) the fitted rectangle\n    ref_pt: reference point to determine the arrow direction\nReturns:\n    angle: the estimated angle in radians in range [0, 2*pi)\n    box: box points of the rectangle\n    arrow: the arrow pointing along the angle\n'''\ndef estimate_angle(rect, ref_pt):\n    # box points\n    box = cv2.boxPoints(rect)\n    # center, width and height\n    (cx, cy), (width, height), _ = rect\n    # specify the direction vector of the pendulum\n    if width < height:\n        v_d = box[0] - box[1]\n    else:\n        v_d = box[0] - box[3]\n    # reference vector\n    v_ref = np.array([cx, cy]) - ref_pt\n    # choose whichever of v_d and -v_d makes an angle of at most pi/2\n    # with v_ref (i.e., a non-negative dot product)\n    if np.dot(v_d, v_ref) < 0:\n        v_d = -v_d\n    # counterclockwise angle from (0,-1) to v_d\n    angle = np.arctan2(-v_d[1], v_d[0]) + np.pi/2\n    if angle < 0:\n        angle += 2*np.pi\n    arrow = ((int(cx), int(cy)), (int(cx+v_d[0]), int(cy+v_d[1])))\n    return angle, box, arrow\n\n\n'''\nObtain the angles of double pendulum from image\nArgs:\n    img: double pendulum image in BGR format\nReturns:\n    rej: (True/False) if the image is rejected\n    angles: the estimated angles in radians in range [0, 2*pi)\n    img_marked: image marked with the fitted rectangles,\n    the direction vectors and the estimated angles (BGR format)\n'''\ndef obtain_angle(img):\n    img_marked = img.copy()\n    seg_1, seg_2 = seg_from_img(img)\n    rej_1, rect_1 = fit_pendulum(seg_1)\n    rej_2, rect_2 = fit_pendulum(seg_2)\n    if rej_1:\n        rej = True\n        angles = []\n        cv2.putText(img_marked, 'Reject arm 1', (10, 20), 
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)\n    elif rej_2:\n        rej = True\n        angles = []\n        cv2.putText(img_marked, 'Reject arm 2', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)\n    else:\n        rej = False\n        img_center = np.array([64, 64])\n        angle_1, box_1, arrow_1 = estimate_angle(rect_1, img_center)\n        (cx, cy), _, _ = rect_1\n        pd_1_end = 2 * np.array([cx, cy]) - img_center\n        angle_2, box_2, arrow_2 = estimate_angle(rect_2, pd_1_end)\n        angles = [angle_1, angle_2]\n        # mark the fitted rectangles (box corners cast to integer pixel coordinates)\n        cv2.drawContours(img_marked, [np.intp(box_1)], 0, (0,0,255), 2)\n        cv2.drawContours(img_marked, [np.intp(box_2)], 0, (0,0,255), 2)\n        # mark the direction vectors\n        cv2.arrowedLine(img_marked, arrow_1[0], arrow_1[1], (0, 0, 255), 1, tipLength=0.25)\n        cv2.arrowedLine(img_marked, arrow_2[0], arrow_2[1], (0, 0, 255), 1, tipLength=0.25)\n        # mark the estimated angles in degrees\n        text = str(round(angle_1*180/np.pi)) + ' ' + str(round(angle_2*180/np.pi))\n        cv2.putText(img_marked, text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)\n\n    return rej, angles, img_marked"
  },
  {
    "path": "analysis/eval_phys_double_pendulum/physics_estimator.py",
    "content": "'''\nThis script provides utility functions estimating angular velocities\nand other physical quantities from angles of the double pendulum.\n'''\n\nimport numpy as np\nfrom scipy.interpolate import CubicSpline\n\n# physical parameters\nfps = 60  # frames per second\nL1, L2 = 0.205, 0.179  # pendulum rod lengths (m)\nw1, w2 = 0.038, 0.038  # pendulum rod widths (m)\nm1, m2 = 0.262, 0.110  # pendulum rod masses (kg)\ng = 9.81  # gravitational acceleration (m/s^2)\n\n'''\nCalculate the absolute difference between two angles th1 and th2 on a circle.\nBoth angles are assumed to lie in [0, 2*pi); the result lies in [0, pi].\n'''\ndef calc_diff(th1, th2):\n    diff = np.abs(th2 - th1)\n    diff = np.minimum(diff, 2*np.pi-diff)\n    return diff\n\n'''\nCalculate the average of two angles th1 and th2 on a circle.\nBoth angles are assumed to lie in [0, 2*pi); the result is the midpoint of the shorter arc between them.\n'''\ndef calc_avrg(th1, th2):\n    avrg = (th1 + th2) / 2\n    diff = np.abs(th2 - th1)\n    if diff > np.pi:\n        avrg -= np.pi\n    if avrg < 0:\n        avrg += 2*np.pi\n    return avrg\n\n'''\nCalculate angular velocities from a sequence of angles\nusing numerical differentiation.\nmethod='fd': finite difference;\nmethod='spline': cubic spline fitting.\n'''\ndef calc_velocity(th, method='spline'):\n    len_seq = th.shape[0]\n\n    # isolated data\n    if len_seq == 1:\n        return np.nan\n    \n    # preprocessing: periodic extension of angles\n    for i in range(1, len_seq):\n        if th[i] - th[i-1] > np.pi:\n            th[i:] -= 2*np.pi\n        elif th[i] - th[i-1] < -np.pi:\n            th[i:] += 2*np.pi\n    \n    vel_th = np.zeros(len_seq)\n\n    # finite difference\n    if method == 'fd':\n        for i in range(1, len_seq):\n            vel_th[i] = (th[i] - th[i-1]) * fps\n        vel_th[0] = (th[1] - th[0]) * fps\n    \n    # cubic spline fitting\n    elif method == 'spline':\n        t = np.arange(len_seq) / fps\n 
       cs = CubicSpline(t, th)\n        vel_th = cs(t, 1)\n        # use finite difference at boundary points to improve accuracy\n        vel_th[0] = (th[1] - th[0]) * fps\n        vel_th[-1] = (th[-1] - th[-2]) * fps\n    \n    else:\n        raise ValueError('Unrecognized differentiation method!')\n    \n    return vel_th\n\n'''\nCalculate kinetic, potential and total energies from angles and angular velocities.\n'''\ndef calc_energy(th1, th2, vel_th1, vel_th2):\n    # centers of mass in x-y coordinates\n    x1 = (L1 / 2) * np.sin(th1)\n    y1 = (-L1 / 2) * np.cos(th1)\n\n    x2 = (L1 * np.sin(th1)) + (L2 / 2) * np.sin(th2)\n    y2 = (-L1 * np.cos(th1)) - (L2 / 2) * np.cos(th2)\n\n    # velocities in x-y coordinates\n    vel_x1 = vel_th1 * (L1 / 2) * np.cos(th1)\n    vel_y1 = vel_th1 * (L1 / 2) * np.sin(th1)\n\n    vel_x2 = vel_th1 * L1 * np.cos(th1) + (vel_th2 / 2) * L2 * np.cos(th2)\n    vel_y2 = vel_th1 * L1 * np.sin(th1) + (vel_th2 / 2) * L2 * np.sin(th2)\n\n    # moments of inertia about each rod's center of mass\n    I1 = (1. / 12.) * m1 * (w1 ** 2 + L1 ** 2)\n    I2 = (1. / 12.) * m2 * (w2 ** 2 + L2 ** 2)\n\n    # potential energy\n    V = g * (m1 * y1 + m2 * y2)\n    # kinetic energy (translation + rotation)\n    T = (0.5 * m1 * (vel_x1 ** 2 + vel_y1 ** 2) + 0.5 * m2 * (vel_x2 ** 2 + vel_y2 ** 2)\n         + 0.5 * I1 * vel_th1 ** 2 + 0.5 * I2 * vel_th2 ** 2)\n    # total energy\n    E = T + V\n\n    return T, V, E"
  },
  {
    "path": "analysis/eval_phys_elastic_pendulum/__init__.py",
    "content": "from .angle_estimator import obtain_angle_stretch\nfrom .physics_estimator import *\n\nimport numpy as np\n\n\nphys_vars_list = {'reject', 'theta_1', 'vel_theta_1', 'theta_2', 'vel_theta_2', 'z', 'vel_z',\n'kinetic energy', 'potential energy', 'total energy'}\n\n\ndef eval_physics(frames):\n    num_frames = len(frames)\n    reject = np.zeros(num_frames, dtype=bool)\n    theta_1 = np.zeros(num_frames)\n    theta_2 = np.zeros(num_frames)\n    z = np.zeros(num_frames)\n    # estimate angles and stretch\n    for p in range(num_frames):\n        reject_p, angles_p, stretch_p, _ = obtain_angle_stretch(frames[p])\n        if reject_p:\n            reject[p] = True\n            theta_1[p] = np.nan\n            theta_2[p] = np.nan\n            z[p] = np.nan\n        else:\n            reject[p] = False\n            theta_1[p] = angles_p[0]\n            theta_2[p] = angles_p[1]\n            z[p] = stretch_p\n    # calculate velocities\n    vel_theta_1 = np.zeros(num_frames)\n    vel_theta_2 = np.zeros(num_frames)\n    vel_z = np.zeros(num_frames)\n    sub_ids = np.ma.clump_unmasked(np.ma.masked_array(theta_1, reject))\n    for ids in sub_ids:\n        vel_theta_1[ids] = calc_velocity(theta_1[ids].copy())\n        vel_theta_2[ids] = calc_velocity(theta_2[ids].copy())\n        vel_z[ids] = calc_velocity(z[ids].copy(), periodic=False)\n    vel_theta_1[reject] = np.nan\n    vel_theta_2[reject] = np.nan\n    vel_z[reject] = np.nan\n    # calculate energies\n    kinetic_energy, potential_energy, total_energy = calc_energy(theta_1, theta_2, z, vel_theta_1, vel_theta_2, vel_z)\n    # save results\n    phys = dict.fromkeys(phys_vars_list)\n    phys['reject'] = reject\n    phys['theta_1'] = theta_1\n    phys['vel_theta_1'] = vel_theta_1\n    phys['theta_2'] = theta_2\n    phys['vel_theta_2'] = vel_theta_2\n    phys['z'] = z\n    phys['vel_z'] = vel_z\n    phys['kinetic energy'] = kinetic_energy\n    phys['potential energy'] = potential_energy\n    phys['total energy'] 
= total_energy\n    return phys"
  },
  {
    "path": "analysis/eval_phys_elastic_pendulum/angle_estimator.py",
    "content": "'''\nThis script provides utility functions estimating the angles and stretch of an elastic double pendulum from an image.\nStep 1. Extract each of the two pendulums from the image.\nStep 2. Fit a rectangle to each pendulum.\nStep 3. Estimate the angles of the pendulums and the stretch of the elastic one.\nCertain images will be rejected if any pendulum is missing, is hidden, or has a wrong shape.\n'''\n\nimport cv2\nimport os\nimport numpy as np\n\n\n'''\nExtract the two pendulums from the image.\nArgs:\n    img: elastic double pendulum image in BGR format\nReturns:\n    seg_1: segmentation of the first pendulum\n    seg_2: segmentation of the second pendulum\n'''\ndef seg_from_img(img):\n    # pixel thresholds (in HSV)\n    v_min_1 = (0, 0, 0)\n    v_max_1 = (255, 255, 140)\n    v_min_2 = (60, 0, 0)\n    v_max_2 = (255, 255, 255)\n    # background color in HSV (RGB: [215, 205, 192])\n    bg_color = [17, 27, 215]\n\n    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)\n    # extract the first pendulum\n    seg_1 = cv2.inRange(img_hsv, v_min_1, v_max_1)\n    # remove the first pendulum by replacing it with background pixels\n    img_hsv[seg_1==255] = bg_color\n    # extract the second pendulum\n    seg_2 = cv2.inRange(img_hsv, v_min_2, v_max_2)\n\n    return seg_1, seg_2\n\n\n'''\nFit a rectangle to the pendulum.\nArgs:\n    seg: segmentation of the pendulum\nReturns:\n    rej: (True/False) whether the image is rejected\n    rect: (Box2D structure) the fitted rectangle\n'''\ndef fit_pendulum(seg):\n    # find all contours\n    contours, _ = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n    # reject if no contours found\n    if len(contours) == 0:\n        return True, None\n    # find the contour with the maximum area\n    cnt = max(contours, key=cv2.contourArea)\n    area = cv2.contourArea(cnt)\n    # reject if the contour is too small\n    if area < 60:\n        return True, None\n    # rectangle fitting\n    rect = 
cv2.minAreaRect(cnt)\n    _, (width, height), _ = rect\n    # reject if the rectangle is too close to a square\n    if abs(height-width) < 4:\n        return True, None\n    # reject if the rectangle does not properly fit the contour\n    if height*width > 1.8 * area:\n        return True, None\n\n    return False, rect\n\n\n'''\nEstimate the angle of the pendulum from its fitted rectangle.\nArgs:\n    rect: (Box2D structure) the fitted rectangle\n    ref_pt: reference point to determine the arrow direction\nReturns:\n    angle: the estimated angle in radians in range [0, 2*pi)\n    box: box points of the rectangle\n    arrow: the arrow pointing along the angle\n'''\ndef estimate_angle(rect, ref_pt):\n    # box points\n    box = cv2.boxPoints(rect)\n    # center, width and height\n    (cx, cy), (width, height), _ = rect\n    # specify the direction vector of the pendulum\n    if width < height:\n        v_d = box[0] - box[1]\n    else:\n        v_d = box[0] - box[3]\n    # reference vector\n    v_ref = np.array([cx, cy]) - ref_pt\n    # choose whichever of v_d and -v_d makes an angle of\n    # at most pi/2 with v_ref\n    if np.dot(v_d, v_ref) < 0:\n        v_d = -v_d\n    # counterclockwise angle from (0,-1) to v_d\n    angle = np.arctan2(-v_d[1], v_d[0]) + np.pi/2\n    if angle < 0:\n        angle += 2*np.pi\n    arrow = ((int(cx), int(cy)), (int(cx+v_d[0]), int(cy+v_d[1])))\n    return angle, box, arrow\n\n\n'''\nObtain the angles and stretch of the elastic double pendulum from an image.\nArgs:\n    img: elastic double pendulum image in BGR format\nReturns:\n    rej: (True/False) whether the image is rejected\n    angles: the estimated angles in radians in range [0, 2*pi)\n    stretch: the estimated stretch of the elastic pendulum (unit: m)\n    img_marked: image marked with the fitted rectangles,\n    the direction vectors and the estimated angles and stretch (BGR format)\n'''\ndef obtain_angle_stretch(img):\n    img_marked = img.copy()\n    seg_1, seg_2 = seg_from_img(img)\n    
rej_1, rect_1 = fit_pendulum(seg_1)\n    rej_2, rect_2 = fit_pendulum(seg_2)\n    if rej_1:\n        rej = True\n        angles = []\n        stretch = None\n        cv2.putText(img_marked, 'Reject arm 1', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)\n    elif rej_2:\n        rej = True\n        angles = []\n        stretch = None\n        cv2.putText(img_marked, 'Reject arm 2', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)\n    else:\n        rej = False\n        img_center = np.array([64, 64])\n        angle_1, box_1, arrow_1 = estimate_angle(rect_1, img_center)\n        (cx, cy), (width, height), _ = rect_1\n        # far end of arm 1: the pivot (image center) reflected through the rectangle center\n        pd_1_end = 2 * np.array([cx, cy]) - img_center\n        angle_2, box_2, arrow_2 = estimate_angle(rect_2, pd_1_end)\n        angles = [angle_1, angle_2]\n        # compute the stretch of the elastic pendulum (unit: m)\n        # max(width, height): stretched length of the elastic pendulum in pixels\n        L_0_img = 24  # unstretched length of the elastic pendulum in pixels\n        L_0 = 0.205   # unstretched length of the elastic pendulum (m)\n        stretch = (max(width, height) - L_0_img) / L_0_img * L_0\n        # mark the fitted rectangles (box corners cast to integer pixel coordinates)\n        cv2.drawContours(img_marked, [box_1.astype(np.intp)], 0, (0,0,255), 2)\n        cv2.drawContours(img_marked, [box_2.astype(np.intp)], 0, (0,0,255), 2)\n        # mark the direction vectors\n        cv2.arrowedLine(img_marked, arrow_1[0], arrow_1[1], (0, 0, 255), 1, tipLength=0.25)\n        cv2.arrowedLine(img_marked, arrow_2[0], arrow_2[1], (0, 0, 255), 1, tipLength=0.25)\n        # mark the estimated angles (degrees) and stretch (m)\n        text = str(round(angle_1*180/np.pi)) + ' ' + str(round(angle_2*180/np.pi)) + ' ' + ('%.2f' % stretch)\n        cv2.putText(img_marked, text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 0, 0), 1)\n\n    return rej, angles, stretch, img_marked"
  },
  {
    "path": "analysis/eval_phys_elastic_pendulum/physics_estimator.py",
    "content": "'''\nThis script provides utility functions estimating velocities\nand other physical quantities of the elastic double pendulum.\n'''\n\nimport numpy as np\nfrom scipy.interpolate import CubicSpline\n\n# physical parameters\nfps = 60     # frames per second\nL_0 = 0.205  # elastic pendulum unstretched length (m)\nL = 0.179    # rigid pendulum rod length (m)\nw = 0.038    # rigid pendulum rod width (m)\nm = 0.110    # rigid pendulum mass (kg)\ng = 9.81     # gravitational acceleration (m/s^2)\nk = 40.0     # elastic constant (kg/s^2)\n\n'''\nCalculate the absolute difference between two angles th1 and th2 on a circle.\nBoth angles are assumed to lie in [0, 2*pi); the result lies in [0, pi].\n'''\ndef calc_diff(th1, th2):\n    diff = np.abs(th2 - th1)\n    diff = np.minimum(diff, 2*np.pi-diff)\n    return diff\n\n'''\nCalculate the average of two angles th1 and th2 on a circle.\nBoth angles are assumed to lie in [0, 2*pi); the result is the midpoint of the shorter arc between them.\n'''\ndef calc_avrg(th1, th2):\n    avrg = (th1 + th2) / 2\n    diff = np.abs(th2 - th1)\n    if diff > np.pi:\n        avrg -= np.pi\n    if avrg < 0:\n        avrg += 2*np.pi\n    return avrg\n\n'''\nNormalize the angle to [-pi, pi].\n'''\ndef normalize_angle(theta):\n    return np.arctan2(np.sin(theta), np.cos(theta))\n\n'''\nCalculate velocities from a sequence of data\nusing numerical differentiation.\nmethod='fd': finite difference;\nmethod='spline': cubic spline fitting.\nperiodic=True: treat the data as angles and unwrap jumps larger than pi.\n'''\ndef calc_velocity(x, method='spline', periodic=True):\n    len_seq = x.shape[0]\n\n    # isolated data\n    if len_seq == 1:\n        return np.nan\n    \n    # preprocessing: periodic extension of angles\n    if periodic:\n        for i in range(1, len_seq):\n            if x[i] - x[i-1] > np.pi:\n                x[i:] -= 2*np.pi\n            elif x[i] - x[i-1] < -np.pi:\n                x[i:] += 2*np.pi\n    \n    vel_x = np.zeros(len_seq)\n\n    # finite difference\n    if 
method == 'fd':\n        for i in range(1, len_seq):\n            vel_x[i] = (x[i] - x[i-1]) * fps\n        vel_x[0] = (x[1] - x[0]) * fps\n    \n    # cubic spline fitting\n    elif method == 'spline':\n        t = np.arange(len_seq) / fps\n        cs = CubicSpline(t, x)\n        vel_x = cs(t, 1)\n        # use finite difference at boundary points to improve accuracy\n        vel_x[0] = (x[1] - x[0]) * fps\n        vel_x[-1] = (x[-1] - x[-2]) * fps\n    \n    else:\n        raise ValueError('Unrecognized differentiation method!')\n    \n    return vel_x\n\n'''\nCalculate kinetic, potential and total energies from angles, stretch and their velocities.\nOnly the rigid pendulum carries mass; the elastic pendulum enters as a massless spring.\n'''\ndef calc_energy(th1, th2, z, vel_th1, vel_th2, vel_z):\n    # center of mass of the rigid pendulum in x-y coordinates\n    x = (L_0 + z) * np.sin(th1) + (L / 2) * np.sin(th2)\n    y = -(L_0 + z) * np.cos(th1) - (L / 2) * np.cos(th2)\n\n    # velocities in x-y coordinates\n    vel_x = vel_th1 * (L_0 + z) * np.cos(th1) + vel_th2 * (L / 2) * np.cos(th2) + vel_z * np.sin(th1)\n    vel_y = vel_th1 * (L_0 + z) * np.sin(th1) + vel_th2 * (L / 2) * np.sin(th2) - vel_z * np.cos(th1)\n\n    # moment of inertia of the rigid pendulum about its center of mass\n    I = (1. / 12.) * m * (w ** 2 + L ** 2)\n\n    # potential energy (gravitational + elastic)\n    V = m * g * y + 0.5 * k * z**2\n\n    # kinetic energy (translation + rotation)\n    T = 0.5 * m * (vel_x ** 2 + vel_y ** 2) + 0.5 * I * vel_th2 ** 2\n\n    # total energy\n    E = T + V\n\n    return T, V, E"
  },
  {
    "path": "analysis/eval_phys_long_term_pred.py",
    "content": "\"\"\"\nEvaluate long-term prediction accuracy in physical variables.\n\"\"\"\n\nimport os\nimport sys\nimport numpy as np\nfrom scipy import stats\nimport cv2\nimport json\nimport yaml\nimport pprint\nfrom tqdm import tqdm\nfrom munch import munchify\n\n\n'''\nCalculate the absolute difference between two angles th1 and th2 on a circle.\nBoth angles are assumed to lie in [0, 2*pi); the result lies in [0, pi].\n'''\ndef calc_diff(th1, th2):\n    diff = np.abs(th2 - th1)\n    diff = np.minimum(diff, 2*np.pi-diff)\n    return diff\n\n'''\nCalculate the pixel-wise mean squared error between two images (intensities scaled to [0, 1]).\n'''\ndef calc_pixel_MSE(img1, img2):\n    return np.mean((img1/255.0 - img2/255.0) ** 2)\n\ndef load_config(filepath):\n    with open(filepath, 'r') as stream:\n        try:\n            trainer_params = yaml.safe_load(stream)\n            return trainer_params\n        except yaml.YAMLError as exc:\n            print(exc)\n\n\ndef eval_pred_physics(data_filepath, pred_save_path, vid_ids, phys_vars_list, eval_physics):\n    phys = {p_var:[] for p_var in phys_vars_list}\n\n    for idx in tqdm(vid_ids):\n        data_vid_filepath = os.path.join(data_filepath, str(idx))\n        pred_vid_filepath = os.path.join(pred_save_path, str(idx))\n        num_frames = len(os.listdir(data_vid_filepath))\n        num_pred_frames = len(os.listdir(pred_vid_filepath))\n        frames = []\n\n        for p in range(num_frames):\n            if (num_pred_frames == num_frames - 2) and (p == 0 or p == 1):\n                # read the initial two frames from ground truth data\n                frame_p = cv2.imread(os.path.join(data_vid_filepath, str(p)+'.png'))\n            else:\n                frame_p = cv2.imread(os.path.join(pred_vid_filepath, str(p)+'.png'))\n            frames.append(frame_p)\n        phys_tmp = eval_physics(frames)\n        \n        for p_var in phys_vars_list:\n            phys[p_var].append(phys_tmp[p_var])\n\n    for p_var in phys_vars_list:\n  
      phys[p_var] = np.array(phys[p_var])\n\n    return phys\n\n\ndef load_data_physics(data_filepath, vid_ids, phys_vars_list):\n    phys = np.load(os.path.join(data_filepath, 'phys_vars.npy'), allow_pickle=True).item()\n    for p_var in phys_vars_list:\n        phys[p_var] = phys[p_var][vid_ids]\n    return phys\n\n\ndef eval_physics_error(phys_pred, phys_data, phys_vars_list):\n    phys_vars_list_2 = [p_var for p_var in phys_vars_list if p_var!='reject']\n    phys_error = {}\n\n    phys_error['reject'] = phys_pred['reject'].copy()\n    if 'reject' in phys_data.keys():\n        phys_error['reject_data'] = phys_data['reject'].copy()\n    else:\n        # no rejection info in the data: treat every frame as accepted (boolean mask like 'reject')\n        phys_error['reject_data'] = np.zeros(phys_pred['reject'].shape, dtype=bool)\n\n    for p_var in phys_vars_list_2:\n        if p_var in ['theta', 'theta_1', 'theta_2']:\n            phys_error[p_var] = calc_diff(phys_pred[p_var], phys_data[p_var])\n        else:\n            phys_error[p_var] = np.abs(phys_pred[p_var] - phys_data[p_var])\n\n    return phys_error\n\n\ndef eval_pixel_error(data_filepath, pred_save_path, vid_ids):\n    pixel_error = []\n\n    for idx in tqdm(vid_ids):\n        data_vid_filepath = os.path.join(data_filepath, str(idx))\n        pred_vid_filepath = os.path.join(pred_save_path, str(idx))\n        num_frames = len(os.listdir(data_vid_filepath))\n        pixel_error_idx = []\n\n        for p in range(num_frames):\n            if p == 0 or p == 1:\n                # the first two frames are ground-truth inputs, not predictions\n                pixel_error_idx.append(0)\n            else:\n                data = cv2.imread(os.path.join(data_vid_filepath, str(p)+'.png'))\n                pred = cv2.imread(os.path.join(pred_vid_filepath, str(p)+'.png'))\n                pixel_error_idx.append(calc_pixel_MSE(data, pred))\n        \n        pixel_error.append(pixel_error_idx)\n\n    return np.array(pixel_error)\n\n\ndef main():\n    config_filepath = str(sys.argv[1])\n    pred_save_path = str(sys.argv[2])\n\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = 
munchify(cfg)\n\n    data_filepath = os.path.join(cfg.data_filepath, cfg.dataset)\n    with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:\n        seq_dict = json.load(file)\n    vid_ids = seq_dict['test']\n\n    if cfg.dataset == 'single_pendulum':\n        from eval_phys_single_pendulum import phys_vars_list, eval_physics\n    elif cfg.dataset == 'double_pendulum':\n        from eval_phys_double_pendulum import phys_vars_list, eval_physics\n    elif cfg.dataset == 'elastic_pendulum':\n        from eval_phys_elastic_pendulum import phys_vars_list, eval_physics\n    else:\n        raise ValueError(f'Unknown system: {cfg.dataset}')\n\n    phys_pred = eval_pred_physics(data_filepath, pred_save_path, vid_ids, phys_vars_list, eval_physics)\n    phys_data = load_data_physics(data_filepath, vid_ids, phys_vars_list)\n\n    phys_error = eval_physics_error(phys_pred, phys_data, phys_vars_list)\n    pixel_error = eval_pixel_error(data_filepath, pred_save_path, vid_ids)\n\n    np.save(os.path.join(pred_save_path, 'phys_vars.npy'), phys_pred)\n    np.save(os.path.join(pred_save_path, 'phys_error.npy'), phys_error)\n    np.save(os.path.join(pred_save_path, 'pixel_error.npy'), pixel_error)\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "analysis/eval_phys_single_pendulum/__init__.py",
    "content": "from .angle_estimator import obtain_angle\nfrom .physics_estimator import *\n\nimport numpy as np\n\n\nphys_vars_list = {'reject', 'theta', 'vel_theta', 'kinetic energy',\n'potential energy', 'total energy'}\n\n\ndef eval_physics(frames):\n    num_frames = len(frames)\n    reject = np.zeros(num_frames, dtype=bool)\n    theta = np.zeros(num_frames)\n    # estimate angles\n    for p in range(num_frames):\n        reject_p, theta_p, _ = obtain_angle(frames[p])\n        if reject_p:\n            reject[p] = True\n            theta[p] = np.nan\n        else:\n            reject[p] = False\n            theta[p] = theta_p\n    # calculate velocities\n    vel_theta = np.zeros(num_frames)\n    sub_ids = np.ma.clump_unmasked(np.ma.masked_array(theta, reject))\n    for ids in sub_ids:\n        vel_theta[ids] = calc_velocity(theta[ids].copy())\n    vel_theta[reject] = np.nan\n    # calculate energies\n    kinetic_energy, potential_energy, total_energy = calc_energy(theta, vel_theta)\n    # save results\n    phys = dict.fromkeys(phys_vars_list)\n    phys['reject'] = reject\n    phys['theta'] = theta\n    phys['vel_theta'] = vel_theta\n    phys['kinetic energy'] = kinetic_energy\n    phys['potential energy'] = potential_energy\n    phys['total energy'] = total_energy\n    return phys"
  },
  {
    "path": "analysis/eval_phys_single_pendulum/angle_estimator.py",
    "content": "'''\nThis script provides utility functions estimating the angle of a single pendulum from an image.\nStep 1. Extract the pendulum from the image.\nStep 2. Fit a rectangle to the pendulum.\nStep 3. Estimate the angle of the pendulum.\nCertain images will be rejected if the pendulum does not exist or has a wrong shape.\n'''\n\nimport cv2\nimport os\nimport numpy as np\n\n\n'''\nExtract the pendulum from the image.\nArgs:\n    img: single pendulum image in BGR format\nReturns:\n    seg: segmentation of the pendulum\n'''\ndef seg_from_img(img):\n    # pixel thresholds (in HSV)\n    v_min = (60, 0, 0)\n    v_max = (255, 255, 255)\n    # extract the pendulum\n    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)\n    seg = cv2.inRange(img_hsv, v_min, v_max)\n\n    return seg\n\n\n'''\nFit a rectangle to the pendulum.\nArgs:\n    seg: segmentation of the pendulum\nReturns:\n    rej: (True/False) whether the image is rejected\n    rect: (Box2D structure) the fitted rectangle\n'''\ndef fit_pendulum(seg):\n    # find all contours\n    contours, _ = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n    # reject if no contours found\n    if len(contours) == 0:\n        return True, None\n    # find the contour with the maximum area\n    cnt = max(contours, key=cv2.contourArea)\n    area = cv2.contourArea(cnt)\n    # reject if the contour is too small\n    if area < 400:\n        return True, None\n    # rectangle fitting\n    rect = cv2.minAreaRect(cnt)\n    _, (width, height), _ = rect\n    # reject if the rectangle is too close to a square\n    if abs(height-width) < 35:\n        return True, None\n    # reject if the rectangle does not properly fit the contour\n    if height*width > 1.5 * area:\n        return True, None\n    \n    return False, rect\n\n\n'''\nEstimate the angle of the pendulum from its fitted rectangle.\nArgs:\n    rect: (Box2D structure) the fitted rectangle\nReturns:\n    angle: the estimated angle in radians in range [0, 2*pi)\n    
box: box points of the rectangle\n    arrow: the arrow pointing along the angle\n'''\ndef estimate_angle(rect):\n    # box points\n    box = cv2.boxPoints(rect)\n    # center, width and height\n    (cx, cy), (width, height), _ = rect\n    # specify the direction vector of the pendulum\n    if width < height:\n        v_d = box[0] - box[1]\n    else:\n        v_d = box[0] - box[3]\n    # reference vector from the pivot (image center) to the rectangle center\n    v_ref = np.array([cx-64, cy-64])\n    # choose whichever of v_d and -v_d makes an angle of\n    # at most pi/2 with v_ref\n    if np.dot(v_d, v_ref) < 0:\n        v_d = -v_d\n    # counterclockwise angle from (0,-1) to v_d\n    angle = np.arctan2(-v_d[1], v_d[0]) + np.pi/2\n    if angle < 0:\n        angle += 2*np.pi\n    arrow = ((int(cx), int(cy)), (int(cx+0.7*v_d[0]), int(cy+0.7*v_d[1])))\n    return angle, box, arrow\n\n\n'''\nObtain the angle of the single pendulum from an image.\nArgs:\n    img: single pendulum image in BGR format\nReturns:\n    rej: (True/False) whether the image is rejected\n    angle: the estimated angle in radians in range [0, 2*pi), or nan if rejected\n    img_marked: image marked with the fitted rectangle,\n    the direction vector and the estimated angle (BGR format)\n'''\ndef obtain_angle(img):\n    img_marked = img.copy()\n    seg = seg_from_img(img)\n    rej, rect = fit_pendulum(seg)\n\n    if not rej:\n        angle, box, arrow = estimate_angle(rect)\n        # mark the fitted rectangle (box corners cast to integer pixel coordinates)\n        cv2.drawContours(img_marked, [box.astype(np.intp)], 0, (0,0,255), 2)\n        # mark the direction vector\n        cv2.arrowedLine(img_marked, arrow[0], arrow[1], (0, 0, 255), 1, tipLength=0.25)\n        # mark the estimated angle in degrees\n        cv2.putText(img_marked, str(round(angle*180/np.pi)), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)\n    else:\n        # mark the rejection\n        cv2.putText(img_marked, 'Reject', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)\n        angle = np.nan\n    \n    return rej, angle, img_marked"
  },
  {
    "path": "analysis/eval_phys_single_pendulum/physics_estimator.py",
    "content": "'''\nThis script provides utility functions estimating angular velocities\nand other physical quantities from angles of the single pendulum.\n'''\n\nimport numpy as np\nfrom scipy.interpolate import CubicSpline\n\n# physical parameters\nfps = 60  # frames per second\nl = 0.5   # pendulum rod length (m)\nm = 1.0   # pendulum rod mass (kg)\ng = 9.81  # gravitational acceleration (m/s^2)\n\n\n'''\nCalculate the absolute difference between two angles th1 and th2 on a circle.\nBoth angles are assumed to lie in [0, 2*pi); the result lies in [0, pi].\n'''\ndef calc_diff(th1, th2):\n    diff = np.abs(th2 - th1)\n    diff = np.minimum(diff, 2*np.pi-diff)\n    return diff\n\n'''\nCalculate the average of two angles th1 and th2 on a circle.\nBoth angles are assumed to lie in [0, 2*pi); the result is the midpoint of the shorter arc between them.\n'''\ndef calc_avrg(th1, th2):\n    avrg = (th1 + th2) / 2\n    diff = np.abs(th2 - th1)\n    if diff > np.pi:\n        avrg -= np.pi\n    if avrg < 0:\n        avrg += 2*np.pi\n    return avrg\n\n'''\nNormalize the angle to [-pi, pi].\n'''\ndef normalize_angle(theta):\n    return np.arctan2(np.sin(theta), np.cos(theta))\n\n'''\nCalculate angular velocities from a sequence of angles\nusing numerical differentiation.\nmethod='fd': finite difference;\nmethod='spline': cubic spline fitting.\n'''\ndef calc_velocity(th, method='spline'):\n    len_seq = th.shape[0]\n\n    # isolated data\n    if len_seq == 1:\n        return np.nan\n    \n    # preprocessing: periodic extension of angles\n    for i in range(1, len_seq):\n        if th[i] - th[i-1] > np.pi:\n            th[i:] -= 2*np.pi\n        elif th[i] - th[i-1] < -np.pi:\n            th[i:] += 2*np.pi\n    \n    vel_th = np.zeros(len_seq)\n\n    # finite difference\n    if method == 'fd':\n        for i in range(1, len_seq):\n            vel_th[i] = (th[i] - th[i-1]) * fps\n        vel_th[0] = (th[1] - th[0]) * fps\n    \n    # cubic spline fitting\n    elif method == 'spline':\n        t = 
np.arange(len_seq) / fps\n        cs = CubicSpline(t, th)\n        vel_th = cs(t, 1)\n        # use finite difference at boundary points to improve accuracy\n        vel_th[0] = (th[1] - th[0]) * fps\n        vel_th[-1] = (th[-1] - th[-2]) * fps\n    \n    else:\n        raise ValueError('Unrecognized differentiation method!')\n    \n    return vel_th\n\n'''\nCalculate energies from angles and angular velocities.\nThe pendulum is modeled as a uniform rod of length l pivoting at one end:\nits moment of inertia is I = m*l^2/3, so T = I*vel_th^2/2 = m*l^2*vel_th^2/6,\nand its center of mass lies at distance l/2 from the pivot.\n'''\ndef calc_energy(th, vel_th):\n    T = m * l**2 / 6 * vel_th**2\n    V = - m * g * l / 2 * np.cos(th)\n    E = T + V\n    return T, V, E\n"
  },
  {
    "path": "analysis/eval_regression.py",
    "content": "import os\nimport sys\nimport numpy as np\nfrom tqdm import tqdm\nimport json\nimport yaml\nimport pprint\nfrom munch import munchify\nfrom latent_regression import pca, mlp_regress\n\n\ndef load_config(filepath):\n    with open(filepath, 'r') as stream:\n        try:\n            trainer_params = yaml.safe_load(stream)\n            return trainer_params\n        except yaml.YAMLError as exc:\n            print(exc)\n\n\n# parse the video no. and frame no. from the data id (example: '7_0.png' -> video 7, frame 0)\ndef parse_data_id(data_id):\n    lhs, rhs = data_id.split('_')\n    vid_n = int(lhs)\n    frm_n = int(rhs.split('.')[0])\n    return vid_n, frm_n\n\n\n# collect each physical variable (except 'reject') at frames frm_n+t, t = 0..3, of every sample\ndef physical_variables_from_data_ids(phys_all, ids):\n    phys_vars_list = []\n    for p_var in phys_all.keys():\n        if p_var == 'reject':\n            continue\n        for t in range(4):\n            phys_vars_list.append(f'{p_var} (t={t})')\n    \n    num_data = ids.shape[0]\n    phys = {p_var:np.zeros(num_data) for p_var in phys_vars_list}\n    \n    for n in range(num_data):\n        vid_n, frm_n = parse_data_id(ids[n])\n        for p_var in phys_all.keys():\n            if p_var == 'reject':\n                continue\n            for t in range(4):\n                phys[f'{p_var} (t={t})'][n] = phys_all[p_var][vid_n, frm_n+t]\n    \n    return phys\n\n\ndef main():\n    config_filepath = str(sys.argv[1])\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n\n    if_refine = ('refine' in cfg.model_name)\n    phys_all = np.load(os.path.join(cfg.data_filepath, cfg.dataset, 'phys_vars.npy'), allow_pickle=True).item()\n    log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])\n\n    # latent vectors and physical variables (train and test)\n    ids_train = np.load(os.path.join(log_dir, 'variables_train', 'ids.npy'))\n    ids_test = np.load(os.path.join(log_dir, 'variables', 'ids.npy'))\n    if if_refine:\n        latent_train = 
np.load(os.path.join(log_dir, 'variables_train', 'refine_latent.npy'))\n        latent_test = np.load(os.path.join(log_dir, 'variables', 'refine_latent.npy'))\n    else:\n        latent_train = np.load(os.path.join(log_dir, 'variables_train', 'latent.npy'))\n        latent_test = np.load(os.path.join(log_dir, 'variables', 'latent.npy'))\n    phys_train = physical_variables_from_data_ids(phys_all, ids_train)\n    phys_test = physical_variables_from_data_ids(phys_all, ids_test)\n    phys_vars_list = phys_train.keys()\n\n    # remove samples with NaN physical variables (e.g. from rejected frames)\n    is_outlier_train = np.full(latent_train.shape[0], False)\n    for p_var in phys_vars_list:\n        is_outlier_train = is_outlier_train | np.isnan(phys_train[p_var])\n    for p_var in phys_vars_list:\n        phys_train[p_var] = phys_train[p_var][~is_outlier_train]\n    latent_train = latent_train[~is_outlier_train]\n    is_outlier_test = np.full(latent_test.shape[0], False)\n    for p_var in phys_vars_list:\n        is_outlier_test = is_outlier_test | np.isnan(phys_test[p_var])\n    for p_var in phys_vars_list:\n        phys_test[p_var] = phys_test[p_var][~is_outlier_test]\n    latent_test = latent_test[~is_outlier_test]\n    print(f'Valid training samples: {latent_train.shape[0]}; test samples: {latent_test.shape[0]}.')\n\n    # optionally keep only the first few principal components (pass 'NA' to skip PCA)\n    if str(sys.argv[2]) == 'NA':\n        num_components = None\n    else:\n        num_components = int(sys.argv[2])\n        latent_train, latent_test, pca_model = pca(latent_train, latent_test, num_components, cfg.seed)\n    \n    # fit regressors on num_samples random subsets, each holding 30% of the training samples\n    num_samples = 3\n    sample_size = int(0.3 * latent_train.shape[0])\n\n    train_error = {p_var:np.zeros(num_samples) for p_var in phys_vars_list}\n    test_error = {p_var:np.zeros(num_samples) for p_var in phys_vars_list}\n    reg_models = {p_var:np.empty(num_samples, dtype=object) for p_var in phys_vars_list}\n\n    rng = np.random.default_rng(cfg.seed)\n    for i in range(num_samples):\n        
print('random samples #'+str(i+1))\n        samp_ids = rng.choice(latent_train.shape[0], sample_size, replace=False)\n        for p_var in phys_vars_list:\n            print('doing '+p_var+' regression...')\n            train_error[p_var][i], test_error[p_var][i], reg_models[p_var][i] = mlp_regress(latent_train[samp_ids], latent_test, phys_train[p_var][samp_ids], phys_test[p_var], cfg.seed)\n\n    if num_components is None:\n        np.save(os.path.join(log_dir, 'variables', 'regression_results.npy'), [train_error, test_error, reg_models])\n    else:\n        np.save(os.path.join(log_dir, 'variables', 'regression_results_pca_'+str(num_components)+'.npy'), [train_error, test_error, pca_model, reg_models])\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/__init__.py",
    "content": "import numpy as np\nfrom .methods import *\n\nclass ID_Estimator:\n    def __init__(self, method='Levina_Bickel'):\n        self.all_methods = ['Levina_Bickel', 'MiND_ML', 'MiND_KL', 'Hein', 'CD']\n        self.set_method(method)\n    \n    def set_method(self, method='Levina_Bickel'):\n        if method not in self.all_methods:\n            raise ValueError(f'Unknown method: {method}')\n        else:\n            self.method = method\n        \n    def fit(self, X, k_list=20, n_jobs=4):\n        if self.method in ['Hein', 'CD']:\n            dim_Hein, dim_CD = Hein_CD(X)\n            return dim_Hein if self.method=='Hein' else dim_CD\n        else:\n            if np.isscalar(k_list):\n                k_list = np.array([k_list])\n            else:\n                k_list = np.array(k_list)\n            kmax = np.max(k_list) + 2\n            dists, inds = kNN(X, kmax, n_jobs)\n            dims = []\n            for k in k_list:\n                if self.method == 'Levina_Bickel':\n                    dims.append(Levina_Bickel(X, dists, k))\n                elif self.method == 'MiND_ML':\n                    dims.append(MiND_ML(X, dists, k))\n                elif self.method == 'MiND_KL':\n                    dims.append(MiND_KL(X, dists, k))\n                else:\n                    pass\n            if len(dims) == 1:\n                return dims[0]\n            else:\n                return np.array(dims)\n    \n    def fit_all_methods(self, X, k_list=[20], n_jobs=4):\n        k_list = np.array(k_list)\n        kmax = np.max(k_list) + 2\n        dists, inds = kNN(X, kmax, n_jobs)\n        dim_all_methods = {method:[] for method in self.all_methods}\n        dim_all_methods['Hein'], dim_all_methods['CD'] = Hein_CD(X)\n        for k in k_list:\n            dim_all_methods['Levina_Bickel'].append(Levina_Bickel(X, dists, k))\n            dim_all_methods['MiND_ML'].append(MiND_ML(X, dists, k))\n            dim_all_methods['MiND_KL'].append(MiND_KL(X, dists, k))\n        for method in self.all_methods:\n            dim_all_methods[method] = np.array(dim_all_methods[method])\n        return dim_all_methods"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/DANCo.m",
    "content": "function [d,kl] = DANCo(data,varargin)\n% DANCO  Estimating the intrinsic dimensionality of a dataset\n%\n% function [d,kl] = DANCo(data, ParamName,ParamValue, ...)\n%\n%  This function estimates the intrinsic dimensionality of a given dataset\n% by means of the DANCo algorithm described in:\n%\n%  \"DANCo: Dimensionality from Angle and Norm Concentration\"\n%  C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli\n%  arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1\n%\n% with the complexity improvement given by separating the training step\n% from the evaluation one (requires fitting models in the file DANCo_fits).\n%\n%  Parameters\n%  ----------\n% IN:\n%  data     = The data matrix (DxN) with a sample per column.\n% Named parameters:\n%  'k'\n%       The number of neighbours to be used, must be >= 1. (def=10)\n%  'minDim'\n%       The minimal number of dimensions. (def=1)\n%  'maxDim'\n%       The maximal number of dimensions. (def=size(data,1))\n%  'sigma'\n%        Standard deviation of the gaussian smoothing filter executed on\n%       the estimated divergences. (def=[]=no smoothing)\n%  'fractal'\n%       Is the fractal dimension required? (def=false)\n%  'mlMethod'\n%       Algorithm used to pre-estimate d:\n%               MLE         = Levina's algorithm.\n%               MiND_MLI    = ML on the first neighbour.\n%               MiND_MLK    = ML on the first neighbour and optimization.\n%             (def='MiND_MLK')\n%  'minCorrectionDim'\n%        Minimal dimensionality estimated with the selected mlMethod such\n%       that DANCO is adopted for correction, for lower dimensionalities\n%       the mlMethod result is produced. (def=5)\n%  'inds'\n%        Precomputed knn indices as returned by the KNN function for k \n%       set to k+2 with respect to the given k parameter. 
(also dists must\n%       be present)\n%  'dists'\n%        Precomputed knn distances as returned by the KNN function for k \n%       set to k+2 with respect to the given k parameter. (also inds must\n%       be present)\n%  'normalized'\n%        Are the inds/dists already normalized? Normalized here means that \n%       all the k+2 distances are normalized wrt the last one and the first\n%       (zero) and the last (one) are removed. (def=true)\n% OUT:\n%  d        = Estimated intrinsic dimensionality of the dataset.\n%  kl       = The Kullback Leibler divergences curve for d=1..D.\n%\n%  Examples\n%  --------\n% X = randsphere(30,1000,1);\n% V = linSubspSpanOrthonormalize(randn(120,30));\n% pts = V*X;\n% [inds,dists] = KNN(pts,12,true);\n% [d,kl] = DANCo(pts,'inds',inds,'dists',dists,'minDim',20,'maxDim',40);\n\n    % Checking parameters:\n    if nargin<1; error('A dataset is required'); end\n    \n    % Infos:\n    [D,N] = size(data);\n\n    % Default parameters:\n    params = struct('k',{10}, ...\n                    'minDim',{1}, 'maxDim',{D}, ...\n                    'fractal',{false}, ...\n                    'sigma',{[]},...\n                    'mlMethod',{'MiND_MLK'}, ...\n                    'normalized',{true}, ...\n                    'minCorrectionDim',{5});\n                \n    % Parsing named parameters:\n    if nargin > 1\n        try\n            params = parseParamsNamed(varargin, false, params);\n        catch e\n            error('Pairs ''name,value'' are required as arguments after the dataset.');\n        end\n    end\n    \n    % Estimating the parameters for the real data:\n    [isLowDim,dHat,mu,tau] = DANCo_statistics(data,params.k,params,nargout > 1);\n\n    % Correction based on the KL-divergence:\n    if not(isLowDim) || nargout > 1\n        % Initializing the parameters:\n        params2 = struct( ...\n            'mlMethod',{params.mlMethod}, ...\n            'minCorrectionDim',{params.minCorrectionDim});\n        numDims = 
params.maxDim - params.minDim + 1;\n\n        % Computing the reference parameters:\n        dHat_ref = zeros(1,numDims);\n        mu_ref = zeros(1,numDims);\n        tau_ref = zeros(1,numDims);\n        for i = 1:numDims\n            % Generating the dataset:\n            data2 = randsphere(params.minDim + i - 1, N);\n            \n            % Estimating the parameters:\n            [~,dHat_ref(i),mu_ref(i),tau_ref(i)] = ...\n                DANCo_statistics(data2,params.k,params2,true);\n        end\n\n        % Estimating the KL-divergence for all the cases d=1:D:\n        kl = DANCo_estimateKL(params.k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref);\n    end\n    \n    % Smoothing if required:\n    if (not(isLowDim) || nargout > 1) && not(isempty(params.sigma))\n        % Preparing the impulse response:\n        rad = 4*ceil(params.sigma/2);\n        h = gauss(-rad:rad,params.sigma,0,false);\n        h = h/sum(h);\n\n        % Filtering:\n        kl = conv([flipdim(kl,2),kl,kl],h,'same');\n        kl = kl(numDims+1:end-numDims);\n    end\n    \n    % The dimensionality estimation:\n    if isLowDim\n        % Using the MLE estimation:\n        if params.fractal; d = dHat;\n        else d = round(dHat); end\n    else\n        % Selecting the minimum kl position:\n        [~,d] = min(kl);\n        \n        % Refining for the fractal dimension:\n        if params.fractal\n            % Fitting with a cubic smoothing spline:\n            fn = csapi(1:numDims,kl);\n            \n            % Locating the minima:\n            opts = optimset('Display','off','TolX',1e-3);\n            d = fminsearch(@(x)(fnval(fn,x)),d,opts);\n        end\n    end\n    \n    % Correcting the dimension considering the minimal one:\n    d = d + params.minDim - 1;\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/DANCoFit.m",
    "content": "function [d,kl] = DANCoFit(data,varargin)\n% DANCOFIT  Estimating the intrinsic dimensionality of a dataset\n%\n% function [d,kl] = DANCoFit(data, ParamName,ParamValue, ...)\n%\n%  This function estimates the intrinsic dimensionality of a given dataset\n% by means of the fast DANCo algorithm described in:\n%\n%  \"DANCo: Dimensionality from Angle and Norm Concentration\"\n%  C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli\n%  arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1\n%\n% with the complexity improvement given by separating the training step\n% from the evaluation one (requires fitting models in the file DANCo_fits).\n%\n%  Parameters\n%  ----------\n% IN:\n%  data     = The data matrix (DxN) with a sample per column.\n% Named parameters:\n%  'fractal'\n%       Is the fractal dimension required? (def=false)\n%  'mlMethod'\n%       Algorithm used to pre-estimate d:\n%               MLE         = Levina's algorithm.\n%               MiND_MLI    = ML on the first neighbour.\n%               MiND_MLK    = ML on the first neighbour and optimization.\n%             (def='MiND_MLK')\n%  'minCorrectionDim'\n%        Minimal dimensionality estimated with the selected mlMethod such\n%       that DANCO is adopted for correction, for lower dimensionalities\n%       the mlMethod result is produced. (def=5)\n%  'inds'\n%        Precomputed knn indices as returned by the KNN function for k \n%       set to k+2 with respect to the given k parameter. (also dists must\n%       be present)\n%  'dists'\n%        Precomputed knn distances as returned by the KNN function for k \n%       set to k+2 with respect to the given k parameter. (also inds must\n%       be present)\n%  'normalized'\n%        Are the inds/dists already normalized? Normalized here means that \n%       all the k+2 distances are normalized wrt the last one and the first\n%       (zero) and the last (one) are removed. 
(def=true)\n%  'modelfile'\n%        The name of the model file to be used. (def='DANCo_fits')\n% OUT:\n%  d        = Estimated intrinsic dimensionality of the dataset.\n%  kl       = The Kullback Leibler divergences curve for d=1..D.\n%\n%  Examples\n%  --------\n% X = randsphere(30,1000,1);\n% V = linSubspSpanOrthonormalize(randn(120,30));\n% pts = V*X;\n% [inds,dists] = KNN(pts,12,true);\n% [d,kl] = DANCoFit(pts,'inds',inds,'dists',dists);\n\n    % Checking parameters:\n    if nargin<1; error('A dataset is required'); end\n    \n    % Infos:\n    [D,N] = size(data);\n    d = 1:D;\n\n    % Default parameters:\n    params = struct('fractal',{false}, ...\n                    'mlMethod',{'MiND_MLK'}, ...\n                    'normalized',{true}, ...\n                    'minCorrectionDim',{5}, ...\n                    'modelfile',{'DANCo_fits'});\n                \n    % Parsing named parameters:\n    if nargin > 1\n        try\n            params = parseParamsNamed(varargin, false, params);\n        catch e\n            error('Pairs ''name,value'' are required as arguments after the dataset.');\n        end\n    end\n    \n    % Loading the model:\n    load(params.modelfile);\n\n    % Estimating the parameters:\n    [isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,nargout > 1);\n\n    % Correction based on the KL-divergence:\n    if not(isLowDim) || nargout > 1\n        % Computing the reference parameters:\n        dHat_ref = feval(fitDhat,d,N);\n        mu_ref = feval(fitMu,d,N);\n        tau_ref = feval(fitTau,d,N);\n\n        % Estimating the KL-divergence for all the cases d=1:D:\n        kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref);\n    end\n    \n    % The dimensionality estimation:\n    if isLowDim\n        % Using the MLE estimation:\n        if params.fractal; d = dHat;\n        else d = round(dHat); end\n    else\n        % Selecting the minimum kl position:\n        [~,d] = min(kl);\n        \n        % Refining for the fractal 
dimension:\n        if params.fractal\n            % Fitting with a cubic smoothing spline:\n            fn = csapi(1:D,kl);\n            \n            % Locating the minima:\n            opts = optimset('Display','off','TolX',1e-3);\n            d = fminsearch(@(x)(fnval(fn,x)),d,opts);\n        end\n    end\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/DANCoTrain.m",
    "content": "function DANCoTrain(modelfile,varargin)\n% DANCoTrain  Training the DANCoFit model file.\n%\n% function DANCoTrain(modelfile, ParamName,ParamValue, ...)\n%\n%  This function executes experiments with points uniformly drawn from\n% hyper-spheres with unit radius to produce the fitting functions used \n% by DANCoFit. If you use this, please cite:\n%\n%  \"DANCo: Dimensionality from Angle and Norm Concentration\"\n%  C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli\n%  arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1\n%\n%  Parameters\n%  ----------\n% IN:\n%  modelfile = The name of the model file to be produced.\n% Named parameters:\n%  'k'\n%       The number of neighbours to be used, must be >= 1. (def=10)\n%  'maxDim'\n%        The maximal dimensionality to be used in the experiments, more\n%       dims means more computation but also more accurate fits. (def=100)\n%  'cardinalities'\n%        The list of cardinalities to be used in the experiments, more\n%       cardinalities means more computation but also more accurate fits, the\n%       suggested choice is to adopt a 1-2-5 series. \n%       (def=[50,100,200,500,1000,2000,5000])\n%  'iterations'\n%        The number of iterations per experiment (results averaged), more \n%       iterations means more computation but also more accurate fits. 
(def=100)\n%  'mlMethod'\n%       Algorithm used to pre-estimate d:\n%               MLE         = Levina's algorithm.\n%               MiND_MLI    = ML on the first neighbour.\n%               MiND_MLK    = ML on the first neighbour and optimization.\n%             (def='MiND_MLK')\n%\n%  Example\n%  -------\n% DANCoTrain('sampleModel','maxDim',20,'cardinalities',[200,500,1000],'iterations',1);\n% X = randsphere(10,500);\n% V = linSubspSpanOrthonormalize(randn(20,10));\n% pts = V*X;\n% DANCoFit(pts,'modelfile','sampleModel')\n\n    % Checking parameters:\n    if nargin<1; error('A model file name is required.'); end\n\n    % Default parameters:\n    params = struct('k',{10}, ...\n                    'maxDim',{100}, ...\n                    'cardinalities',{[50,100,200,500,1000,2000,5000]}, ...\n                    'iterations',{100}, ...\n                    'mlMethod',{'MiND_MLK'});\n                \n    % Parsing named parameters:\n    if nargin > 1\n        try\n            params = parseParamsNamed(varargin, false, params);\n        catch e\n            error('Pairs ''name,value'' are required as arguments after the dataset.');\n        end\n    end\n    \n    % Parameters to be passed to the statistics computation function:\n    params2 = struct( ...\n        'mlMethod',{params.mlMethod}, ...\n        'minCorrectionDim',{1});\n    k = params.k;\n    \n    % Computing the size of the statistic matrices:\n    N = numel(params.cardinalities);\n    D = params.maxDim;\n    \n    % Initialization:\n    dHat = zeros(D-1,N);\n    mu = zeros(D-1,N);\n    tau = zeros(D-1,N);\n    for d = 2:D\n        for n = 1:N\n            % Initializing the data of this iteration:\n            dHat_ref = zeros(1,params.iterations);\n            mu_ref = zeros(1,params.iterations);\n            tau_ref = zeros(1,params.iterations);\n\n            % Executing the experiments:\n            for i = 1:params.iterations\n                % Generating the dataset:\n                data = 
randsphere(d, params.cardinalities(n));\n                \n                % Reporting progress to the user:\n                fprintf('(d = %d \\tn = %d \\ti = %d)...', ...\n                    d,params.cardinalities(n),i);\n\n                % Estimating the parameters:\n                [~,dHat_ref(i),mu_ref(i),tau_ref(i)] = ...\n                    DANCo_statistics(data,k,params2,true);\n                \n                % Reporting progress to the user:\n                fprintf(' \\tDONE!\\n');\n            end\n            \n            % Storing the results:\n            dHat(d-1,n) = mean(dHat_ref);\n            mu(d-1,n) = mean(mu_ref);\n            tau(d-1,n) = mean(tau_ref);\n        end\n    end\n    \n    % Reporting progress to the user:\n    fprintf('Fitting & saving...');\n    \n    % Initializing:\n    inVars = {2:D,params.cardinalities};\n                \n    % Fitting the functions:\n    fitDhatFun = fnxtr(csaps(inVars,dHat),2);\n    fitMuFun = fnxtr(csaps(inVars,mu),2);\n    fitTauFun = fnxtr(csaps(inVars,tau),2);\n    \n    % Producing function handles:\n    fitDhat = @(D,N)(fnval(fitDhatFun,{D,N})');\n    fitMu = @(D,N)(fnval(fitMuFun,{D,N})');\n    fitTau = @(D,N)(fnval(fitTauFun,{D,N})');\n    \n    % Writing the model:\n    save(modelfile,'fitMu','fitTau','fitDhat','k');\n    \n    % Reporting progress to the user:\n    fprintf(' \\tDONE!\\n');\n\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/KNN.m",
    "content": "function [inds,dists] = KNN(X,k,normalized)\n%KNN  A multi-version knnsearch\n%\n% function [inds,dists] = KNN(X,k,normalized)\n%\n%  This function identifies the indexes (and optionally the\n% distances) of the first k neighbours of each given columnwise point.\n%\n% IN:\n%  X        = The dataset, each column represents a point.\n%  k        = The number of neighbours of interest. (def=10)\n%  normalized = Must the k neighbours be normalized? \n%           (triggers a k+2 search, def=false)\n% OUT:\n%  inds     = The i-th row contains the neighbour indexes of the i-th point.\n%  dists    = The i-th row contains the neighbour distances of the i-th point.\n\n    % Checking parameters:\n    if nargin<1; error('Dataset required.'); end\n    if nargin<2; k=10; end\n    if nargin<3; normalized=false; end\n\n    % Handling the normalized case:\n    if normalized; k=k+2; end\n    \n    % Checking the version:\n    if datenum(version('-date')) < 734729\n        % Computing all the distances:\n        dists = L2(X, X);\n        \n        % Sorting:\n        [dists, inds] = sort(dists, 2);\n        \n        % Choosing the first k elements:\n        dists = dists(:,1:k);\n        inds = inds(:,1:k);\n    else\n        % The Matlab implementation:\n        X = X';\n        [inds, dists] = knnsearch(X,X,'K',k);\n    end\n    \n    % Handling the normalized case:\n    if normalized\n        % Cropping both inds and dists:\n        inds = inds(:,2:end-1);\n        dists = dists(:,2:end);\n        \n        % Normalizing and removing the last element:\n        dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]);\n    end\nend\n\n% ------------------------ LOCAL FUNCTIONS ------------------------\n\n% L2 distance (by Roland Bunschoten):\nfunction d = L2(a,b)\n    d = real(sqrt(bsxfun(@plus, sum(a .* a)', bsxfun(@minus, sum(b .* b), 2 * a' * b))));\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/MLE.m",
    "content": "function [d,ds] = MLE(data,varargin)\n%MLE  Intrinsic dimensionality using Levina's MLE technique\n%\n% function [d,ds] = MLE(data, ParamName,ParamValue, ...)\n%\n%  This function estimates the intrinsic dimensionality of a dataset using \n% Levina's Maximum Likelihood Estimator (MLE).\n%\n%  Parameters\n%  ----------\n% IN:\n%  data     = Matrix DxN containing the dataset as columnwise points.\n% Named parameters:\n%  'k'\n%       The number of neighbours to be used, must be >= 1. (def=10)\n%  'dists'\n%        Precomputed knn distances as returned by the KNN function for k \n%       set to k+2 with respect to the given k parameter (k is set to \n%       \"size(dists,2)-2\" if knn is present).\n%  'normalized'\n%        Are the dists already normalized? Normalized here means that all\n%       the k+2 distances are normalized wrt the last one and the first\n%       (zero) and the last (one) are removed. (def=true)\n% OUT:\n%  d    = Estimated intrinsic dimensionality of the dataset.\n%  ds   = Local estimations.\n%\n%  Examples\n%  --------\n% X = randsphere(30,1000,1);\n% V = linSubspSpanOrthonormalize(randn(120,30));\n% pts = V*X;\n% [inds,dists] = KNN(pts,12,true);\n% d = MLE(pts,'dists',dists);\n\n    % Checking parameters:\n    if nargin<1; error('A dataset is required'); end\n    \n    % Infos:\n    N = size(data,2);\n    \n    % Default parameters:\n    params = struct('k',{10},'normalized',{true});\n    \n    % Parsing named parameters:\n    if nargin > 1\n        try\n            params = parseParamsNamed(varargin, false, params);\n        catch e\n            error('Pairs ''name,value'' are required as arguments after the dataset.');\n        end\n    end\n\n    % Is the knn already there?\n    mustDoKnn = true;\n    if isfield(params,'dists')\n        % Extracting and checking:\n        dists = params.dists;\n        if size(dists,1) == N && size(dists,2) > 2\n            % KNN already present:\n            params.k = 
size(dists,2)-2;\n            mustDoKnn = false;\n        end\n    end\n    \n    % Finding the KNNs:\n    if mustDoKnn\n        if not(isfield(params,'k'))\n            error('One parameter of ''dists'' or ''k'' is required.');\n        end\n        [~,dists] = KNN(data,params.k,true);\n    end\n    \n    % Normalizing the distances:\n    if not(mustDoKnn || params.normalized)\n        % Cropping the dists:\n        dists = dists(:,2:end);\n        \n        % Normalizing and removing the last element:\n        dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]);\n    end\n\n    \n    % Local estimation:\n    ds = -1./mean(log(dists),2);\n    \n    % Harmonic mean:\n    d = 1./mean(1./ds);\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/MiND_KL.m",
    "content": "function [d,kl] = MiND_KL(data,varargin)\n%MiND_KL  Intrinsic dimensionality using the MiND_KL technique\n%\n% function [d,kl] = MiND_KL(data, ParamName,ParamValue, ...)\n%\n%  This function estimates the intrinsic dimensionality of a dataset using \n% the Minimum Neighbor Distance estimators described in:\n%\n% \"Minimum Neighbor Distance Estimators of Intrinsic Dimension\",\n%  A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,\n%  Published at the European Conference on Machine Learning (ECML 2011).\n%\n%  Parameters\n%  ----------\n% IN:\n%  data     = Matrix DxN containing the dataset as columnwise points.\n% Named parameters:\n%  'k'\n%       The number of neighbours to be used, must be >= 1. (def=10)\n%  'minDim'\n%       The minimal number of dimensions. (def=1)\n%  'maxDim'\n%       The maximal number of dimensions. (def=size(data,1))\n%  'sigma'\n%        Standard deviation of the gaussian smoothing filter executed on\n%       the estimated divergences. (def=[]=no smoothing)\n%  'dists'\n%        Precomputed knn distances as returned by the KNN function for k \n%       set to k+2 with respect to the given k parameter (k is set to \n%       \"size(dists,2)-2\" if knn is present).\n%  'normalized'\n%        Are the dists already normalized? Normalized here means that all\n%       the k+2 distances are normalized wrt the last one and the first\n%       (zero) and the last (one) are removed. 
(def=true)\n% OUT:\n%  d    = Estimated intrinsic dimensionality of the dataset.\n%  kl   = The Kullback Leibler divergences curve for d=1..D.\n%\n%  Examples\n%  --------\n% X = randsphere(30,1000,1);\n% V = linSubspSpanOrthonormalize(randn(120,30));\n% pts = V*X;\n% [inds,dists] = KNN(pts,12,true);\n% [d,kl] = MiND_KL(pts,'dists',dists);\n\n    % Checking parameters:\n    if nargin<1; error('A dataset is required'); end\n    \n    % Infos:\n    [D,N] = size(data);\n    \n    % Default parameters:\n    params = struct('k',{10}, ...\n                    'normalized',{true}, ...\n                    'minDim',{1}, 'maxDim',{D}, ...\n                    'sigma',[]);\n                \n    % Parsing named parameters:\n    if nargin > 1\n        try\n            params = parseParamsNamed(varargin, false, params);\n        catch e\n            error('Pairs ''name,value'' are required as arguments after the dataset.');\n        end\n    end\n\n    % Is the knn already there?\n    mustDoKnn = true;\n    if isfield(params,'dists')\n        % Extracting and checking:\n        dists = params.dists;\n        if size(dists,1) == N && size(dists,2) > 2\n            % KNN already present:\n            params.k = size(dists,2)-2;\n            mustDoKnn = false;\n        end\n    end\n    \n    % Finding the KNNs:\n    if mustDoKnn\n        if not(isfield(params,'k'))\n            error('One parameter of ''dists'' or ''k'' is required.');\n        end\n        [~,dists] = KNN(data,params.k+2);\n    end\n    \n    % Normalizing the distances:\n    if mustDoKnn || not(params.normalized)\n        % Cropping the dists:\n        dists = dists(:,2:end);\n        \n        % Normalizing and removing the last element:\n        dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]);\n    end\n    \n    % Only the first normalized neighbour is required:\n    k = size(dists,2);\n    dists = dists(:,1)';\n\n    % Executing the experiments with unit uniformly-sampled 
hyper-spheres:\n    numDims = params.maxDim - params.minDim + 1;\n    kl = zeros(1,numDims);\n    for i=1:numDims\n        % Generating the data:\n        data2 = randsphere(params.minDim + i - 1, N);\n        \n        % Computing the normalized knn distances:\n        [~,distsExp] = KNN(data2,k,true);\n        distsExp = distsExp(:,1)';\n        \n        % Estimating the Kullback-Leibler divergence:\n        kl(i) = wangKL1d(distsExp,dists);\n    end\n    \n    % Smoothing if required:\n    if not(isempty(params.sigma))\n        % Preparing the impulse response:\n        rad = 4*ceil(params.sigma/2);\n        h = gauss(-rad:rad,params.sigma,0,false);\n        h = h/sum(h);\n        \n        % Filtering:\n        kl = conv([flipdim(kl,2),kl,kl],h,'same');\n        kl = kl(numDims+1:end-numDims);\n    end\n\n    % Estimating the id:\n    [~,d] = min(kl);\n    d = d + params.minDim - 1;\nend\n\n% ------------------------ LOCAL FUNCTIONS ------------------------\n\n% Wang's Kullback-Leibler divergence estimator in one dimension:\nfunction div = wangKL1d(data1,data2)\n    % Getting data:\n    n = numel(data1);\n    m = numel(data2);\n    data1 = sort(data1);\n    \n    % Generating the nearest neighbors for data1:\n    nnData1 = [-inf,data1,inf];\n    nnData1 = abs(nnData1(1:end-1)-nnData1(2:end));\n    nnData1 = min([nnData1(1:end-1);nnData1(2:end)],[],1);\n    \n    % Generating the nearest neighbors for data2:\n    nnData2 = min(abs(repmat(data1,[m,1]) - repmat(data2',[1,n])),[],1);\n    \n    % Computing the kl div:\n    div = abs(sum(log(nnData2./(nnData1+eps)))/n + log(m/(n-1)));\nend\n\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/MiND_ML.m",
    "content": "function d = MiND_ML(data,varargin)\n%MiND_ML  Intrinsic dimensionality using the MiND_ML techniques\n%\n% function d = MiND_ML(data, ParamName,ParamValue, ...)\n%\n%  This function estimates the intrinsic dimensionality of a dataset using \n% the Minimum Neighbor Distance estimators described in:\n%\n% \"Minimum Neighbor Distance Estimators of Intrinsic Dimension\",\n%  A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,\n%  Published at the European Conference on Machine Learning (ECML 2011).\n%\n%  Parameters\n%  ----------\n% IN:\n%  data     = Matrix DxN containing the dataset as columnwise points.\n% Named parameters:\n%  'k'\n%       The number of neighbours to be used, must be >= 1. (def=10)\n%  'optimize'\n%        Must an optimization step be executed to estimate a fractal id,\n%       that is MiND_MLK instead of MiND_MLI? (def=false)\n%  'dists'\n%        Precomputed knn distances as returned by the KNN function for k \n%       set to k+2 with respect to the given k parameter (k is set to \n%       \"size(dists,2)-2\" if knn is present).\n%  'normalized'\n%        Are the dists already normalized? Normalized here means that all\n%       the k+2 distances are normalized wrt the last one and the first\n%       (zero) and the last (one) are removed. 
(def=true)\n% OUT:\n%  d    = Estimated intrinsic dimensionality of the dataset.\n%\n%  Examples\n%  --------\n% X = randsphere(30,1000,1);\n% V = linSubspSpanOrthonormalize(randn(120,30));\n% pts = V*X;\n% [inds,dists] = KNN(pts,12,true);\n% d = MiND_ML(pts,'dists',dists);\n\n    % Checking parameters:\n    if nargin<1; error('A dataset is required'); end\n    \n    % Infos:\n    [D,N] = size(data);\n    \n    % Default parameters:\n    params = struct('k',{10},'optimize',{false},'normalized',{true});\n    \n    % Parsing named parameters:\n    if nargin > 1\n        try\n            params = parseParamsNamed(varargin, false, params);\n        catch e\n            error('Pairs ''name,value'' are required as arguments after the dataset.');\n        end\n    end\n\n    % Is the knn already there?\n    mustDoKnn = true;\n    if isfield(params,'dists')\n        % Extracting and checking:\n        dists = params.dists;\n        if size(dists,1) == N && size(dists,2) > 2\n            % KNN already present:\n            params.k = size(dists,2)-2;\n            mustDoKnn = false;\n        end\n    end\n    \n    % Finding the KNNs:\n    if mustDoKnn\n        if not(isfield(params,'k'))\n            error('One parameter of ''dists'' or ''k'' is required.');\n        end\n        [~,dists] = KNN(data,params.k,true);\n    end\n    \n    % Normalizing the distances:\n    if not(mustDoKnn || params.normalized)\n        % Cropping the dists:\n        dists = dists(:,2:end);\n        \n        % Normalizing and removing the last element:\n        dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]);\n    end\n\n    \n    % Infos:\n    [n,k] = size(dists);\n    \n    % The first normalized distance:\n    r = dists(:,1);\n    \n    % The log likelihood derivative squared:\n    fun = @(d)((n/d)+sum(log(r)-((k-1).*(log(r).*r.^d))./(1-r.^d)))^2;\n    \n    % Evaluating it:\n    vals = zeros(1,D);\n    for d=1:D; vals(d)=fun(d); end\n    \n    % MiND_MLI result:\n    
[~,d] = min(vals);\n    \n    % The optimization:\n    if params.optimize\n        opts = optimset('Display','off','TolX',1e-3);\n        d = fminsearch(fun,d,opts);\n    end\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/demo_idEstimation.m",
    "content": "%% Intrinsic Dimensionality (id) estimation\n%\n% In the past decade the development of automatic techniques to estimate \n% the intrinsic dimensionality of a given dataset has gained considerable \n% attention due to its relevance in several application elds. \n%\n% In this small toolbox some of the state-of-art techniques are implemented\n% as functions that use a \"standard\" interface to use them, thus allowing\n% client code to be decoupled from the knowledge of the applied technique.\n% \n% For more details see/cite:\n%\n%  \"Minimum Neighbor Distance Estimators of Intrinsic Dimension\",\n%  A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,\n%  Published to the European Conference of Machine Learning (ECML 2011).\n%\n%  \"DANCo: Dimensionality from Angle and Norm Concentration\",\n%  C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli\n%  arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1\n\n%% Usage of this id-estimation toolbox\n%\n% As an example here we generate a simple dataset and test all the \n% implemented techniques on it to compare the results, the dataset is an \n% hypersphere in R^{10} linearly embedded in R^{20}.\n\n% Generating the dataset:\nd = 10; D = 20;\nX = randsphere(d,500);\nV = linSubspSpanOrthonormalize(randn(D,d));\ndata = V*X;\n\n% A first base estimation:\nfprintf('MLE      => %2.2f\\n', MLE(data));\nfprintf('DANCoFit => %2.2f\\n', DANCoFit(data));\n\n%%\n\n% The techniques:\nestimators = {'MLE','MiND_ML','MiND_KL','DANCo','DANCoFit'};\n\n% Initializing the results:\nidEst = zeros(1,numel(estimators));\nspentTime = zeros(1,numel(estimators));\n\n% Running the experiments:\nfor i = 1:numel(estimators)\n    % Obtaining the funciton:\n    estimator = str2func(estimators{i});\n    \n    % Estimating:\n    tic;\n    idEst(i) = estimator(data);\n    spentTime(i) = toc;\nend\n\n% Plotting the results:\nfigure; bar([idEst(:),spentTime(:)]); hold 
on;\nplot([0,numel(estimators)+1],[d,d],'b:');\nset(gca,'xticklabel',estimators);\nlegend({'Estimated id','Spent time'},'Location','NorthWest');\ntitle(sprintf('Estimating the id of a %dd hypersphere linearly embedded in R^{%d}',d,D));\n\n%% Estimators based on the Kullback-Leibler divergence\n%\n% Some of the presented algorithms estimate the Kullback-Leibler divergence\n% between the distribution of some statistics computed on the real dataset\n% and those computed on synthetically generated data (hyperspheres) or by\n% means of fitting functions (trained using the statistics produced by\n% synthetic data). Below are the results obtained by some of them.\n\n% Initialization:\nestimators = {'MiND_KL','DANCo','DANCoFit'};\n\n% Initializing the results:\nidEst = zeros(1,numel(estimators));\nspentTime = zeros(1,numel(estimators));\nkls = zeros(numel(estimators),D);\n\n% Running the experiments:\nfor i = 1:numel(estimators)\n    % Obtaining the function:\n    estimator = str2func(estimators{i});\n    \n    % Estimating:\n    tic;\n    [idEst(i),kls(i,:)] = estimator(data);\n    spentTime(i) = toc;\nend\n\n% Notice that DANCo produces unreliable results for d=1:\nDs = 2:D;\nkls2 = kls(:,Ds);\n\n% Plotting the results:\nfigure; plot(Ds,kls2'); hold on; \nplot([idEst;idEst],[zeros(size(idEst));max(kls2(:))*ones(size(idEst))],':');\nlegend(estimators,'Location','NorthEast');\ntitle('KL-based id estimators: the estimated KLs');\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/gauss.m",
    "content": "function func = gauss(x, sigma, mu, norm)\n% GAUSS  Returns the one-dimensional Gaussian function\n%\n%  Evaluates the Gaussian function on an array of points.\n%\n%  Params:\n%\n%  x:       The array of x points (def=[-10:10])\n%  sigma:   The standard deviation (def=2.5)\n%  mu:      The mean of the Gaussian (def=0)\n%  norm:    Whether to normalize it to unit area (def=true)\n\n% Check x\nif nargin<1\n    x = -10:10;\nend\n\n% Check sigma\nif nargin<2\n    sigma = 2.5;\nend\n\n% Check mean\nif nargin<3\n    mu = 0;\nend\n\n% Check norm\nif nargin<4\n    norm = true;\nend\n\n% Computing the Gaussian function\nif norm\n    func = exp(-(x-mu).^2/(2*sigma^2))/(sigma*sqrt(2*pi));\nelse\n    func = exp(-(x-mu).^2/(2*sigma^2));\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/html/demo_idEstimation.html",
    "content": "\n<!DOCTYPE html\n  PUBLIC \"-//W3C//DTD HTML 4.01 Transitional//EN\">\n<html><head>\n      <meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\">\n   <!--\nThis HTML was auto-generated from MATLAB code.\nTo make changes, update the MATLAB code and republish this document.\n      --><title>Intrinsic Dimensionality (id) estimation</title><meta name=\"generator\" content=\"MATLAB 7.13\"><link rel=\"schema.DC\" href=\"http://purl.org/dc/elements/1.1/\"><meta name=\"DC.date\" content=\"2013-01-31\"><meta name=\"DC.source\" content=\"demo_idEstimation.m\"><style type=\"text/css\">\n\nbody {\n  background-color: white;\n  margin:10px;\n}\n\nh1 {\n  color: #990000; \n  font-size: x-large;\n}\n\nh2 {\n  color: #990000;\n  font-size: medium;\n}\n\n/* Make the text shrink to fit narrow windows, but not stretch too far in \nwide windows. */ \np,h1,h2,div.content div {\n  max-width: 600px;\n  /* Hack for IE6 */\n  width: auto !important; width: 600px;\n}\n\npre.codeinput {\n  background: #EEEEEE;\n  padding: 10px;\n}\n@media print {\n  pre.codeinput {word-wrap:break-word; width:100%;}\n} \n\nspan.keyword {color: #0000FF}\nspan.comment {color: #228B22}\nspan.string {color: #A020F0}\nspan.untermstring {color: #B20000}\nspan.syscmd {color: #B28C00}\n\npre.codeoutput {\n  color: #666666;\n  padding: 10px;\n}\n\npre.error {\n  color: red;\n}\n\np.footer {\n  text-align: right;\n  font-size: xx-small;\n  font-weight: lighter;\n  font-style: italic;\n  color: gray;\n}\n\n  </style></head><body><div class=\"content\"><h1>Intrinsic Dimensionality (id) estimation</h1><!--introduction--><p>In the past decade the development of automatic techniques to estimate the intrinsic dimensionality of a given dataset has gained considerable attention due to its relevance in several application fields.</p><p>In this small toolbox some of the state-of-the-art techniques are implemented as functions that share a \"standard\" interface, thus allowing client code to 
be decoupled from the knowledge of the applied technique.</p><p>For more details see/cite:</p><pre>\"Minimum Neighbor Distance Estimators of Intrinsic Dimension\",\nA.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,\nPublished in the European Conference on Machine Learning (ECML 2011).</pre><pre>\"DANCo: Dimensionality from Angle and Norm Concentration\",\nC. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli\narXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1</pre><!--/introduction--><h2>Contents</h2><div><ul><li><a href=\"#1\">Usage of this id-estimation toolbox</a></li><li><a href=\"#3\">Estimators based on the Kullback-Leibler divergence</a></li></ul></div><h2>Usage of this id-estimation toolbox<a name=\"1\"></a></h2><p>As an example, we generate a simple dataset and run all the implemented techniques on it to compare their results. The dataset is a hypersphere in R^{10} linearly embedded in R^{20}.</p><pre class=\"codeinput\"><span class=\"comment\">% Generating the dataset:</span>\nd = 10; D = 20;\nX = randsphere(d,500);\nV = linSubspSpanOrthonormalize(randn(D,d));\ndata = V*X;\n\n<span class=\"comment\">% A first base estimation:</span>\nfprintf(<span class=\"string\">'MLE      =&gt; %2.2f\\n'</span>, MLE(data));\nfprintf(<span class=\"string\">'DANCoFit =&gt; %2.2f\\n'</span>, DANCoFit(data));\n</pre><pre class=\"codeoutput\">MLE      =&gt; 8.07\nDANCoFit =&gt; 10.00\n</pre><pre class=\"codeinput\"><span class=\"comment\">% The techniques:</span>\nestimators = {<span class=\"string\">'MLE'</span>,<span class=\"string\">'MiND_ML'</span>,<span class=\"string\">'MiND_KL'</span>,<span class=\"string\">'DANCo'</span>,<span class=\"string\">'DANCoFit'</span>};\n\n<span class=\"comment\">% Initializing the results:</span>\nidEst = zeros(1,numel(estimators));\nspentTime = zeros(1,numel(estimators));\n\n<span class=\"comment\">% Running the experiments:</span>\n<span class=\"keyword\">for</span> i = 1:numel(estimators)\n    <span 
class=\"comment\">% Obtaining the function:</span>\n    estimator = str2func(estimators{i});\n\n    <span class=\"comment\">% Estimating:</span>\n    tic;\n    idEst(i) = estimator(data);\n    spentTime(i) = toc;\n<span class=\"keyword\">end</span>\n\n<span class=\"comment\">% Plotting the results:</span>\nfigure; bar([idEst(:),spentTime(:)]); hold <span class=\"string\">on</span>;\nplot([0,numel(estimators)+1],[d,d],<span class=\"string\">'b:'</span>);\nset(gca,<span class=\"string\">'xticklabel'</span>,estimators);\nlegend({<span class=\"string\">'Estimated id'</span>,<span class=\"string\">'Spent time'</span>},<span class=\"string\">'Location'</span>,<span class=\"string\">'NorthWest'</span>);\ntitle(sprintf(<span class=\"string\">'Estimating the id of a %dd hypersphere linearly embedded in R^{%d}'</span>,d,D));\n</pre><img vspace=\"5\" hspace=\"5\" src=\"demo_idEstimation_01.png\" alt=\"\"> <h2>Estimators based on the Kullback-Leibler divergence<a name=\"3\"></a></h2><p>Some of the presented algorithms estimate the Kullback-Leibler divergence between the distribution of some statistics computed on the real dataset and those computed on synthetically generated data (hyperspheres) or by means of fitting functions (trained using the statistics produced by synthetic data). 
Below are the results obtained by some of them.</p><pre class=\"codeinput\"><span class=\"comment\">% Initialization:</span>\nestimators = {<span class=\"string\">'MiND_KL'</span>,<span class=\"string\">'DANCo'</span>,<span class=\"string\">'DANCoFit'</span>};\n\n<span class=\"comment\">% Initializing the results:</span>\nidEst = zeros(1,numel(estimators));\nspentTime = zeros(1,numel(estimators));\nkls = zeros(numel(estimators),D);\n\n<span class=\"comment\">% Running the experiments:</span>\n<span class=\"keyword\">for</span> i = 1:numel(estimators)\n    <span class=\"comment\">% Obtaining the function:</span>\n    estimator = str2func(estimators{i});\n\n    <span class=\"comment\">% Estimating:</span>\n    tic;\n    [idEst(i),kls(i,:)] = estimator(data);\n    spentTime(i) = toc;\n<span class=\"keyword\">end</span>\n\n<span class=\"comment\">% Notice that DANCo produces unreliable results for d=1:</span>\nDs = 2:D;\nkls2 = kls(:,Ds);\n\n<span class=\"comment\">% Plotting the results:</span>\nfigure; plot(Ds,kls2'); hold <span class=\"string\">on</span>;\nplot([idEst;idEst],[zeros(size(idEst));max(kls2(:))*ones(size(idEst))],<span class=\"string\">':'</span>);\nlegend(estimators,<span class=\"string\">'Location'</span>,<span class=\"string\">'NorthEast'</span>);\ntitle(<span class=\"string\">'KL-based id estimators: the estimated KLs'</span>);\n</pre><img vspace=\"5\" hspace=\"5\" src=\"demo_idEstimation_02.png\" alt=\"\"> <p class=\"footer\"><br>\n      Published with MATLAB&reg; 7.13<br></p></div><!--\n##### SOURCE BEGIN #####\n%% Intrinsic Dimensionality (id) estimation\n%\n% In the past decade the development of automatic techniques to estimate \n% the intrinsic dimensionality of a given dataset has gained considerable \n% attention due to its relevance in several application fields. 
\n%\n% In this small toolbox some of the state-of-the-art techniques are\n% implemented as functions that share a \"standard\" interface, thus allowing\n% client code to be decoupled from the knowledge of the applied technique.\n% \n% For more details see/cite:\n%\n%  \"Minimum Neighbor Distance Estimators of Intrinsic Dimension\",\n%  A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,\n%  Published in the European Conference on Machine Learning (ECML 2011).\n%\n%  \"DANCo: Dimensionality from Angle and Norm Concentration\",\n%  C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli\n%  arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1\n\n%% Usage of this id-estimation toolbox\n%\n% As an example, we generate a simple dataset and run all the implemented\n% techniques on it to compare their results. The dataset is a hypersphere\n% in R^{10} linearly embedded in R^{20}.\n\n% Generating the dataset:\nd = 10; D = 20;\nX = randsphere(d,500);\nV = linSubspSpanOrthonormalize(randn(D,d));\ndata = V*X;\n\n% A first base estimation:\nfprintf('MLE      => %2.2f\\n', MLE(data));\nfprintf('DANCoFit => %2.2f\\n', DANCoFit(data));\n\n%%\n\n% The techniques:\nestimators = {'MLE','MiND_ML','MiND_KL','DANCo','DANCoFit'};\n\n% Initializing the results:\nidEst = zeros(1,numel(estimators));\nspentTime = zeros(1,numel(estimators));\n\n% Running the experiments:\nfor i = 1:numel(estimators)\n    % Obtaining the function:\n    estimator = str2func(estimators{i});\n    \n    % Estimating:\n    tic;\n    idEst(i) = estimator(data);\n    spentTime(i) = toc;\nend\n\n% Plotting the results:\nfigure; bar([idEst(:),spentTime(:)]); hold on;\nplot([0,numel(estimators)+1],[d,d],'b:');\nset(gca,'xticklabel',estimators);\nlegend({'Estimated id','Spent time'},'Location','NorthWest');\ntitle(sprintf('Estimating the id of a %dd hypersphere linearly embedded in R^{%d}',d,D));\n\n%% Estimators based on the Kullback-Leibler divergence\n%\n% Some of the presented 
algorithms estimate the Kullback-Leibler divergence\n% between the distribution of some statistics computed on the real dataset\n% and those computed on synthetically generated data (hyperspheres) or by\n% means of fitting functions (trained using the statistics produced by\n% synthetic data). Below are the results obtained by some of them.\n\n% Initialization:\nestimators = {'MiND_KL','DANCo','DANCoFit'};\n\n% Initializing the results:\nidEst = zeros(1,numel(estimators));\nspentTime = zeros(1,numel(estimators));\nkls = zeros(numel(estimators),D);\n\n% Running the experiments:\nfor i = 1:numel(estimators)\n    % Obtaining the function:\n    estimator = str2func(estimators{i});\n    \n    % Estimating:\n    tic;\n    [idEst(i),kls(i,:)] = estimator(data);\n    spentTime(i) = toc;\nend\n\n% Notice that DANCo produces unreliable results for d=1:\nDs = 2:D;\nkls2 = kls(:,Ds);\n\n% Plotting the results:\nfigure; plot(Ds,kls2'); hold on; \nplot([idEst;idEst],[zeros(size(idEst));max(kls2(:))*ones(size(idEst))],':');\nlegend(estimators,'Location','NorthEast');\ntitle('KL-based id estimators: the estimated KLs');\n\n##### SOURCE END #####\n--></body></html>"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/linSubspSpanOrthonormalize.m",
    "content": "function Vo = linSubspSpanOrthonormalize(V)\n% LINSUBSPSPANORTHONORMALIZE  Orthonormalizes a linear subspace span matrix.\n%\n% function Vo = linSubspSpanOrthonormalize(V)\n%\n%  This function takes a DxN matrix whose columns are N linearly\n% independent vectors of a D-dimensional space, spanning an N-dimensional\n% linear subspace, and orthonormalizes them so that they span the same\n% subspace but are mutually orthogonal unit vectors (versors). Infinitely\n% many solutions exist; the one returned (classical Gram-Schmidt)\n% normalizes the first vector, orthonormalizes the second with respect to\n% the first within the plane spanned by the first two vectors,\n% orthonormalizes the third with respect to the first two within the 3D\n% space spanned by the first three vectors, and so on. Necessarily\n% 1<=N<=D.\n%\n%  Parameters\n%  ----------\n% IN:\n%  V    = The DxN matrix containing the N vectors.\n% OUT:\n%  Vo   = The orthonormalized linear subspace basis.\n%\n%  Pre\n%  ---\n% -  The columns of V must be linearly independent (i.e. rank(V')==N).\n% -  The number of columns of V must be >=1 and <=D.\n%\n%  Post\n%  ----\n% -  The returned Vo contains columns that are normal\n%   (i.e. norm(Vo(:,i))==1).\n% -  The returned Vo contains columns that are orthogonal\n%   (i.e. 
dot(Vo(:,i),Vo(:,j))==0 for i~=j).\n\n% Check params:\nif size(V,2)<1 || size(V,2)>size(V,1)\n    error('N columns must be >=1 and <=D rows!');\nend\n\n% Check for linear dependence:\nif rank(V')~=size(V,2)\n    error('The V vectors must be linearly independent!');\nend\n\n% Normalizing the V vectors:\nnormV = sqrt(sum(V.^2));\nnormV(normV==0) = eps;\nV = V./repmat(normV,[size(V,1),1]);\n\n% Orthogonalizing (classical Gram-Schmidt):\nfor i=2:size(V,2)\n    % Computing the projection onto the span of the previous versors:\n    proj = V(:,1)*dot(V(:,1),V(:,i));\n    for j=2:i-1\n        % The projection of the Vi vector over the Vj orthonormal versor:\n        proj = proj + V(:,j)*dot(V(:,j),V(:,i));\n    end\n    % Removing and normalizing:\n    V(:,i) = V(:,i) - proj;\n    V(:,i) = V(:,i)./norm(V(:,i));\nend\n\n% Returning:\nVo = V;\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/parseParamsNamed.m",
    "content": "function pars = parseParamsNamed(parsCell,recurse,parsi,multi)\n% PARSEPARAMSNAMED  Getting named parameters as a structure.\n%\n% function pars = parseParamsNamed(parsCell,recurse,parsi,multi)\n%\n%  This function returns named parameters as a struct object. Named\n% parameters are a sequence of name/value pairs, where each string is the\n% name of a parameter and the element following it is its value.\n%\n%  Parameters\n%  ----------\n% IN:\n%  parsCell     = The cell containing the parameters (like varargin).\n%  recurse      = Whether the parsing recurses into sub-cells. (def=false)\n%  parsi        = Initial struct (with default values). (def=struct)\n%  multi        = Whether all the values given for a name are collected in a cell. (def=false)\n% OUT:\n%  pars         = The structure containing the parameters.\n\n% Check parameters:\nif nargin<1 || ~isa(parsCell,'cell')\n    error('A cell with parameters must be passed as parameter!');\nend\nif nargin<2; recurse=false; end\nif nargin<3; parsi=struct; end\nif nargin<4; multi=false; end\n\n% Check for the size:\nN = numel(parsCell);\nif mod(N,2)~=0\n    error('The cell must contain an even number of elements!');\nend\n\n% Init:\npars = parsi;\n\n% Iterating on parameters:\nfor i=1:2:N\n    % Get the name:\n    name = parsCell{i};\n    if ~isa(name,'char')\n        error('The name of parameters must be a string!');\n    end\n    \n    % Getting the value:\n    value = parsCell{i+1};\n    if recurse && isa(value,'cell')\n        % Generating the substructure:\n        value = parseParamsNamed(value,recurse);\n    end\n    \n    % Checking if multi:\n    if multi\n        % Init:\n        if not(isfield(pars,name))\n            pars.(name) = {};\n        end\n        if not(isa(pars.(name),'cell'))\n            pars.(name) = {pars.(name)};\n        end\n        \n        % Packing:\n        pars.(name){end+1} = value;\n    else\n        % Keeping the last value:\n        pars.(name) = value;\n    end\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/private/DANCo_estimateKL.m",
    "content": "function kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref)\n%DANCO_ESTIMATEKL Estimating the Kullback-Leibler divergence for DANCo\n%\n% function kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref)\n%\n%  This function estimates the Kullback-Leibler divergence between the\n% distances/angles bivariate distributions over two datasets whose\n% statistic parameters are {dHat,mu,tau} and {dHat_ref,mu_ref,tau_ref}\n% respectively.\n%\n%  Parameters\n%  ----------\n% IN:\n%  k                        = The KNN parameter.\n%  dHat,mu,tau              = Statistics for the real dataset.\n%  dHat_ref,mu_ref,tau_ref  = Statistics for the synthetic dataset.\n\n    % Checking parameters:\n    if nargin<7; error('Too few parameters'); end\n\n    % Doing the computation:\n    kl = EstimateKL_dists(dHat,dHat_ref,k) + ...\n        EstimateKL_angles(mu,tau,mu_ref,tau_ref);\n    \nend\n\n% ------------------------ LOCAL FUNCTIONS ------------------------\n\n% Estimating the KL-divergence for all the cases d=1:D:\nfunction kl = EstimateKL_angles(m1,k1,m2,k2)\n    % The Bessel values:\n    bk2 = besseli(0,k2);\n    bk1 = besseli(0,k1);\n    ak1 = besseli(1,k1);\n    ck1 = besseli(1,-k1);\n\n    % Checking for infs:\n    if(isinf(k1)); k1 = realmax; end\n    k2(isinf(k2)) = realmax;\n    bk2(isinf(bk2)) = realmax;\n    if(isinf(bk1)); bk1 = realmax; end\n    if(isinf(ak1)); ak1 = realmax; end\n    if(isinf(ck1)); ck1 = realmax; end\n    \n    % Computing the KL:\n    kl =  abs(log(bk2./bk1) + (ak1-ck1)./(2*bk1).*(k1-k2.*cos(m2-m1)));\nend\n\n% -----------------------------------------------------------------\n\n% Estimating the KL-divergence for the distances:\nfunction kl = EstimateKL_dists(d1,d2,k)\n    % Initializing:\n    bin = zeros(numel(d2),k+1);\n    \n    % Computing the terms in the summation:\n    for i=1:k+1\n        bin(:,i) = (-1)^(i-1)*nchoosek(k,i-1).*psi(1+(i-1)*(d1./d2)); \n    end\n    \n    % Estimating the KL in closed form:\n    kl 
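= -(1+harmonic(k-1)-harmonic(k)*(d2./d1)+log(d2./d1)+(k-1)*sum(bin,2)');\nend\n\n% -----------------------------------------------------------------\n\n% Harmonic numbers (by David Terr), via the asymptotic expansion\n% H_z ~ log(z) + gamma + 1/(2z) - 1/(12z^2) + 1/(120z^4) - 1/(252z^6),\n% where gamma = 0.5772156649... is the Euler-Mascheroni constant:\nfunction h = harmonic(z)\n    if z == 1\n        h = 1;\n    else\n        h = log(z) + 0.5772156649 + 1/(2*z) - 1/(12*z^2) + 1/(120*z^4) - 1/(252*z^6); \n    end\nend\n"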
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/private/DANCo_statistics.m",
    "content": "function [isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,force)\n%DANCO_STATISTICS Computes the statistics used by DANCo\n%\n% function [isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,force)\n%\n%  This function computes the statistic parameters used by DANCo to\n% estimate the intrinsic dimensionality. When isLowDim==true only the dHat\n% value is computed and the other outputs are set to 0, unless force==true.\n%\n%  Parameters\n%  ----------\n% IN:\n%  data     = The dataset as a matrix with columnwise vectors.\n%  k        = KNN parameter.\n%  params   = The named parameters passed to DANCo (see DANCo or DANCoFit).\n%  force    = Forces the computation of mu and tau. (def=false).\n% OUT:\n%  isLowDim =  True if the intrinsic dimensionality is considered low, in\n%             which case the dHat estimate is considered good enough.\n%  dHat     = Estimated id by means of biased estimators like MLE or MiND.\n%  mu,tau   = Von Mises circular statistic parameters.\n\n    % Checking parameters:\n    if nargin<3; error('A dataset, k, and the parameters are required'); end\n    if nargin<4; force=false; end\n    if not(isa(params,'struct')); error('Params must be a struct'); end\n    if not(isfield(params,'mlMethod')); error('An mlMethod is required'); end\n    if not(isfield(params,'minCorrectionDim')); error('A minCorrectionDim is required'); end\n    \n    % Infos:\n    N = size(data,2);\n    \n    % Is the knn already there?\n    mustDoKnn = true;\n    if isfield(params,'inds') && isfield(params,'dists')\n        % Extracting and checking:\n        inds = params.inds;\n        dists = params.dists;\n        if size(dists,1) == N && size(dists,2) >= k+2 && size(inds,1) == N && size(inds,2) >= k+2\n            % KNN already present:\n            dists = dists(:,1:k+2);\n            inds = inds(:,1:k+2);\n            mustDoKnn = false;\n        end\n    end\n    \n    % Finding the KNNs only if needed:\n    if mustDoKnn\n     
   % It's needed!\n        [inds,dists] = KNN(data,k,true);\n    end\n    \n    % Normalizing the distances:\n    if not(mustDoKnn || (isfield(params,'normalized') && params.normalized))\n        % Cropping both dists and inds:\n        dists = dists(:,2:end);\n        inds = inds(:,2:end);\n        \n        % Normalizing and removing the last element:\n        dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]);\n    end\n    \n    % Estimating dHat:\n    switch params.mlMethod\n        case 'MLE'\n            dHat = MLE(data,'dists',dists,'normalized',true);\n        case 'MiND_MLI'\n            dHat = MiND_ML(data,'dists',dists,'optimize',false,'normalized',true);\n        case 'MiND_MLK'\n            dHat = MiND_ML(data,'dists',dists,'optimize',true,'normalized',true);\n        otherwise\n            error(['Unknown mlMethod: ' params.mlMethod]);\n    end\n    \n    % Correction based on the KL-divergence:\n    isLowDim = dHat < params.minCorrectionDim;\n    if not(isLowDim) || force\n        % The circular statistics:\n        angles = ComputeAngles(data,inds);\n        [mu,tau] = CircStats(angles);\n\n        % Averaging the parameters:\n        mu = angle(sum(exp(1i*mu)));\n        tau = mean(tau);\n    else\n        % Placeholder parameters:\n        mu = 0;\n        tau = 0;\n    end\n\nend\n\n% ------------------------ LOCAL FUNCTIONS ------------------------\n\n% Computing the used angular statistics:\nfunction [mu,tau] = CircStats(angles)\n    % Computing the complex circular means:\n    cm = sum(exp(1i*angles),2);\n    \n    % Computing the von Mises parameters:\n    mu = angle(cm);\n    r = abs(cm)/size(angles,2);\n    \n    % Inverting the mean resultant length r = A1(tau) with the usual\n    % piecewise approximation, applied elementwise (r holds one value per\n    % point, so a scalar if/else would pick a single branch for all):\n    tau = zeros(size(r));\n    m1 = r < 0.53;\n    m2 = r >= 0.53 & r < 0.85;\n    m3 = r >= 0.85;\n    tau(m1) = 2*r(m1) + r(m1).^3 + 5*r(m1).^5./6;\n    tau(m2) = -0.4 + 1.39*r(m2) + 0.43./(1-r(m2));\n    tau(m3) = 1./(r(m3).^3 - 4*r(m3).^2 + 3*r(m3));\n    \n    % Correcting for small sample sizes:\n    if size(angles,2) < 15\n        mask = tau < 2;\n        tau(mask) = 
max(tau(mask)-2./(size(angles,2).*tau(mask)),0);\n        tau(~mask) = (size(angles,2)-1).^3.*tau(~mask)./(size(angles,2).^3+size(angles,2));\n    end\nend\n\n% -----------------------------------------------------------------\n\n% Computing the pairwise angles for neighbours:\nfunction angles = ComputeAngles(data,inds)\n    % Init:\n    K = size(inds,2);\n    angles = zeros(size(inds,1),K*(K-1)/2);\n    % All index pairs (nchoosek is base MATLAB; the column order of the\n    % pairs does not affect the circular statistics computed from them):\n    cnk = nchoosek(1:K,2);\n    \n    % Iterating on points:\n    for i=1:size(data,2)\n        % Translating and normalizing points:\n        pts = data(:,inds(i,:)) - repmat(data(:,i),[1,K]);\n        pts = pts./repmat(sqrt(sum(pts.^2,1)),[size(pts,1),1]);\n        \n        % Computing all the pairwise angles:\n        a = pts(:,cnk(:,1)); b = pts(:,cnk(:,2));\n        angles(i,:) = atan2(sqrt(sum((b-repmat(sum(a.*b,1),[size(a,1),1]).*a).^2,1)),sum(a.*b,1));\n    end\nend\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/matlab_codes/randsphere.m",
    "content": "function X = randsphere(d,N,r)\n% RANDSPHERE  Random points uniformly drawn from a hypersphere with radius r\n%\n% function X = randsphere(d,N,r)\n%\n%  This function generates random data samples uniformly distributed inside\n% a d-dimensional hypersphere (i.e. the solid ball) with a given radius.\n% The algorithm is the one reported by Roger Stafford on\n% matlabcentral/fileexchange.\n%\n%  Parameters\n%  ----------\n% IN:\n%  d    = Space dimensionality. (def=2)\n%  N    = Number of points. (def=1000)\n%  r    = Hypersphere radius. (def=1)\n% OUT:\n%  X    = Matrix of points (columnwise).\n\n    % Checking parameters:\n    if nargin<1; d=2; end\n    if nargin<2; N=1000; end\n    if nargin<3; r=1; end\n\n    % Gaussian randomly sampled points:\n    X = randn(d,N);\n    \n    % Squared norms:\n    s2 = sum(X.^2,1);\n    \n    % Correcting the point norms: gammainc(s2/2,d/2) is the chi-square CDF\n    % (d degrees of freedom) evaluated at s2, hence uniform in [0,1]; its\n    % 1/d power gives radii that fill the ball uniformly, while X./sqrt(s2)\n    % keeps a uniformly distributed direction:\n    X = X.*repmat(r*(gammainc(s2/2,d/2).^(1/d))./sqrt(s2),d,1);\nend\n\n"
  },
  {
    "path": "analysis/intrinsic_dimension_estimation/methods.py",
    "content": "import numpy as np\nimport os\nfrom sklearn.neighbors import NearestNeighbors\n\n\ndef kNN(X, n_neighbors, n_jobs):\n    # The training points themselves are queried, so column 0 of dists/inds is\n    # normally the point itself (distance 0) and the next columns its neighbours.\n    neigh = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=n_jobs).fit(X)\n    dists, inds = neigh.kneighbors(X)\n    return dists, inds\n\n\ndef Levina_Bickel(X, dists, k):\n    # Levina-Bickel MLE of the intrinsic dimension with the (k-2) bias\n    # correction: per point, (k-2) / sum_{j=1}^{k-1} log(T_k / T_j), where T_j\n    # is the distance to the j-th neighbour, then averaged over all points.\n    # dists must come from kNN() with n_neighbors >= k+1; X is unused.\n    m = np.log(dists[:, k:k+1] / dists[:, 1:k])\n    m = (k-2) / np.sum(m, axis=1)\n    dim = np.mean(m)\n    return dim\n\n\ndef start_matlab_engine():\n    # Starts a MATLAB session with the bundled matlab_codes/ folder on its path.\n    import matlab.engine\n    eng = matlab.engine.start_matlab()\n    eng.addpath(os.path.join(os.path.split(__file__)[0], 'matlab_codes'))\n    return eng\n\n\ndef MiND_ML(X, dists, k):\n    import matlab.engine\n    eng = start_matlab_engine()\n    X_mat = matlab.double(X.T.tolist())\n    dists_mat = matlab.double(dists[:, :k+2].tolist())\n    dim = eng.MiND_ML(X_mat, 'dists', dists_mat, 'normalized', False, 'optimize', True)\n    return dim\n\n\ndef MiND_KL(X, dists, k, maxDim=30):\n    import matlab.engine\n    eng = start_matlab_engine()\n    X_mat = matlab.double(X.T.tolist())\n    dists_mat = matlab.double(dists[:, :k+2].tolist())\n    dim = eng.MiND_KL(X_mat, 'k', matlab.double([k]), 'maxDim', matlab.double([maxDim]),\n                      'dists', dists_mat, 'normalized', False, nargout=1)\n    return dim\n\n\ndef DANCo(X, dists, inds, k, maxDim=30):\n    import matlab.engine\n    eng = start_matlab_engine()\n    X_mat = matlab.double(X.T.tolist())\n    dists_mat = matlab.double(dists[:, :k+2].tolist())\n    inds_mat = matlab.int32((inds[:, :k+2]+1).tolist())  # fit Matlab indices\n    dim = eng.DANCo(X_mat, 'k', matlab.double([k]), 'maxDim', matlab.double([maxDim]), 'fractal', True, \n                    'dists', dists_mat, 'inds', inds_mat, 'normalized', False, nargout=1)\n    return dim\n\n\ndef Hein_CD(X):\n    import matlab.engine\n    eng = start_matlab_engine()\n    X_mat = matlab.double(X.T.tolist())\n    dim = eng.GetDim(X_mat, nargout=1)\n    dim = np.array(dim)[0]\n    return dim[0], 
dim[1]"
  },
  {
    "path": "analysis/latent_regression/__init__.py",
    "content": "from .regressors import *"
  },
  {
    "path": "analysis/latent_regression/regressors.py",
    "content": "import numpy as np\nfrom sklearn.preprocessing import StandardScaler\nfrom sklearn.decomposition import PCA\nfrom sklearn.neural_network import MLPRegressor\nfrom sklearn.linear_model import LinearRegression\n\n\ndef mlp_regress(X_train, X_test, y_train, y_test, random_state=0):\n    scaler = StandardScaler()\n    nn = MLPRegressor(hidden_layer_sizes=(64, 64, 64, 64, 64), activation='relu', solver='adam',\n                      learning_rate='adaptive', alpha=0.01, random_state=random_state)\n    X_train_scaled = scaler.fit_transform(X_train)\n    X_test_scaled = scaler.transform(X_test)\n    y_norm = np.max(np.abs(y_train))\n    nn.fit(X_train_scaled, y_train / y_norm)\n    train_error = np.mean(np.abs(nn.predict(X_train_scaled) * y_norm - y_train)) \n    test_error = np.mean(np.abs(nn.predict(X_test_scaled) * y_norm - y_test))\n    return train_error, test_error, {'scaler':scaler, 'model':nn}\n\n\ndef lin_regress(X_train, X_test, y_train, y_test):\n    reg = LinearRegression()\n    reg.fit(X_train, y_train)\n    train_error = np.mean(np.abs(reg.predict(X_train) - y_train))\n    test_error = np.mean(np.abs(reg.predict(X_test) - y_test))\n    return train_error, test_error, reg\n\n\ndef pca(X_train, X_test, num_components, random_state=0):\n    pca = PCA(num_components, random_state=random_state)\n    pca.fit(X_train)\n    return pca.transform(X_train), pca.transform(X_test), pca"
  },
  {
    "path": "configs/air_dancer/latentpred/config1.yaml",
    "content": "lr: 0.01\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/latentpred/config2.yaml",
    "content": "lr: 0.01\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/latentpred/config3.yaml",
    "content": "lr: 0.02\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/refine64/config1.yaml",
    "content": "lr: 0.0006\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/refine64/config2.yaml",
    "content": "lr: 0.0003\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/air_dancer/refine64/config3.yaml",
    "content": "lr: 0.0004\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'air_dancer'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/refine64/config1.yaml",
    "content": "lr: 0.0003\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/refine64/config2.yaml",
    "content": "lr: 0.0003\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/circular_motion/refine64/config3.yaml",
    "content": "lr: 0.0003\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'circular_motion'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/latentpred/config1.yaml",
    "content": "lr: 0.03\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/latentpred/config2.yaml",
    "content": "lr: 0.03\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/latentpred/config3.yaml",
    "content": "lr: 0.03\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/refine64/config1.yaml",
    "content": "lr: 0.0003\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/refine64/config2.yaml",
    "content": "lr: 0.0005\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/double_pendulum/refine64/config3.yaml",
    "content": "lr: 0.0004\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 256\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'double_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/latentpred/config1.yaml",
    "content": "lr: 0.02\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/latentpred/config2.yaml",
    "content": "lr: 0.02\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/latentpred/config3.yaml",
    "content": "lr: 0.01\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/refine64/config1.yaml",
    "content": "lr: 0.0007\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/refine64/config2.yaml",
    "content": "lr: 0.0004\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/elastic_pendulum/refine64/config3.yaml",
    "content": "lr: 0.0002\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 256\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'elastic_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/latentpred/config1.yaml",
    "content": "lr: 0.008\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/latentpred/config2.yaml",
    "content": "lr: 0.02\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/latentpred/config3.yaml",
    "content": "lr: 0.009\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/refine64/config1.yaml",
    "content": "lr: 0.0007\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/refine64/config2.yaml",
    "content": "lr: 0.0006\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/fire/refine64/config3.yaml",
    "content": "lr: 0.0007\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'fire'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/latentpred/config1.yaml",
    "content": "lr: 0.008\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/latentpred/config2.yaml",
    "content": "lr: 0.009\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/latentpred/config3.yaml",
    "content": "lr: 0.01\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/refine64/config1.yaml",
    "content": "lr: 0.0006\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/refine64/config2.yaml",
    "content": "lr: 0.0004\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/lava_lamp/refine64/config3.yaml",
    "content": "lr: 0.0005\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'lava_lamp'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/latentpred/config1.yaml",
    "content": "lr: 0.01\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/latentpred/config2.yaml",
    "content": "lr: 0.03\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/latentpred/config3.yaml",
    "content": "lr: 0.01\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/refine64/config1.yaml",
    "content": "lr: 0.0004\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/refine64/config2.yaml",
    "content": "lr: 0.0003\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/reaction_diffusion/refine64/config3.yaml",
    "content": "lr: 0.0003\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 256\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'reaction_diffusion'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/latentpred/config1.yaml",
    "content": "lr: 0.005\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/latentpred/config2.yaml",
    "content": "lr: 0.01\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/latentpred/config3.yaml",
    "content": "lr: 0.02\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/refine64/config1.yaml",
    "content": "lr: 0.0003\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/refine64/config2.yaml",
    "content": "lr: 0.0003\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/single_pendulum/refine64/config3.yaml",
    "content": "lr: 0.0004\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'single_pendulum'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/refine64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/refine64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_magnetic/refine64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'swingstick_magnetic'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/latentpred/config1.yaml",
    "content": "lr: 0.02\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/latentpred/config2.yaml",
    "content": "lr: 0.03\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/latentpred/config3.yaml",
    "content": "lr: 0.01\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'latent-prediction'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/model/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/model/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/model/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/model64/config1.yaml",
    "content": "lr: 0.001\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/model64/config2.yaml",
    "content": "lr: 0.001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/model64/config3.yaml",
    "content": "lr: 0.001\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'encoder-decoder-64'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [20, 50, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/refine64/config1.yaml",
    "content": "lr: 0.0005\nseed: 1\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/refine64/config2.yaml",
    "content": "lr: 0.0001\nseed: 2\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "configs/swingstick_non_magnetic/refine64/config3.yaml",
    "content": "lr: 0.0004\nseed: 3\nif_cuda: True\ngamma: 0.5\nlog_dir: 'logs'\ntrain_batch: 512\nval_batch: 256\ntest_batch: 256\nnum_workers: 8\nmodel_name: 'refine-64'\ndata_filepath: './data/'\ndataset: 'swingstick_non_magnetic'\nlr_schedule: [15, 30, 100, 300]\nnum_gpus: 1\nepochs: 1000"
  },
  {
    "path": "datainfo/README.md",
    "content": "## Data Collection\n\n- single_pendulum, elastic_pendulum, reaction_diffusion, circular_motion: simulation code used to generate the data.\n- double_pendulum: post-processing code.\n- fire, lava_lamp: code to split long sequences.\n- utils: utilities to split long sequences.\n\n## Data Structure for Using Your Own Dataset\n\nAll of our datasets follow the structure below: one numbered subfolder per sequence, each containing numbered PNG frames. Please organize any new dataset the same way to use the existing dataloader. If you want a different file organization, you will need to write your own dataloader.\n\n```\n\\double_pendulum\n    \\0\n        \\0.png\n        \\1.png\n        \\...\n    \\1\n        \\...\n    \\2\n        \\...\n    ...\n```"
  },
  {
    "path": "datainfo/air_dancer/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        2,\n        18,\n        4\n    ],\n    \"val\": [\n        15,\n        3,\n        8\n    ],\n    \"train\": [\n        23,\n        24,\n        11,\n        10,\n        22,\n        1,\n        5,\n        25,\n        19,\n        13,\n        9,\n        17,\n        16,\n        0,\n        7,\n        21,\n        6,\n        12,\n        20,\n        14\n    ]\n}"
  },
  {
    "path": "datainfo/air_dancer/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        24,\n        2,\n        1\n    ],\n    \"val\": [\n        9,\n        5,\n        11\n    ],\n    \"train\": [\n        3,\n        20,\n        17,\n        15,\n        0,\n        23,\n        4,\n        7,\n        14,\n        16,\n        19,\n        22,\n        12,\n        18,\n        10,\n        13,\n        21,\n        25,\n        6,\n        8\n    ]\n}"
  },
  {
    "path": "datainfo/air_dancer/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        17,\n        18,\n        7\n    ],\n    \"val\": [\n        19,\n        11,\n        4\n    ],\n    \"train\": [\n        6,\n        16,\n        23,\n        14,\n        1,\n        5,\n        21,\n        10,\n        9,\n        13,\n        25,\n        12,\n        3,\n        8,\n        22,\n        20,\n        0,\n        2,\n        24,\n        15\n    ]\n}"
  },
  {
    "path": "datainfo/circular_motion/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        524,\n        1032,\n        960,\n        907,\n        977,\n        840,\n        877,\n        802,\n        392,\n        5,\n        746,\n        980,\n        623,\n        1068,\n        675,\n        1098,\n        931,\n        470,\n        361,\n        591,\n        867,\n        975,\n        352,\n        526,\n        414,\n        237,\n        561,\n        880,\n        942,\n        552,\n        1059,\n        789,\n        12,\n        1005,\n        464,\n        345,\n        348,\n        806,\n        631,\n        89,\n        961,\n        60,\n        1002,\n        758,\n        1046,\n        335,\n        221,\n        898,\n        177,\n        767,\n        751,\n        354,\n        848,\n        827,\n        497,\n        983,\n        70,\n        805,\n        1034,\n        1022,\n        581,\n        621,\n        388,\n        1039,\n        1072,\n        1025,\n        681,\n        247,\n        607,\n        380,\n        204,\n        852,\n        44,\n        593,\n        941,\n        448,\n        472,\n        707,\n        477,\n        1015,\n        896,\n        454,\n        59,\n        864,\n        443,\n        780,\n        18,\n        52,\n        45,\n        62,\n        650,\n        209,\n        468,\n        545,\n        912,\n        4,\n        886,\n        798,\n        58,\n        999,\n        192,\n        429,\n        777,\n        967,\n        920,\n        1014,\n        241,\n        522,\n        129,\n        275\n    ],\n    \"val\": [\n        456,\n        164,\n        736,\n        36,\n        149,\n        406,\n        1075,\n        230,\n        21,\n        1017,\n        442,\n        620,\n        214,\n        1096,\n        747,\n        259,\n        111,\n        264,\n        1089,\n        815,\n        431,\n        351,\n        395,\n        319,\n        24,\n        116,\n        485,\n        508,\n        
329,\n        719,\n        465,\n        301,\n        728,\n        663,\n        279,\n        672,\n        172,\n        540,\n        1063,\n        163,\n        171,\n        71,\n        297,\n        1011,\n        189,\n        639,\n        944,\n        112,\n        1004,\n        255,\n        287,\n        773,\n        772,\n        14,\n        463,\n        17,\n        888,\n        85,\n        72,\n        689,\n        861,\n        33,\n        261,\n        836,\n        871,\n        816,\n        564,\n        817,\n        93,\n        881,\n        185,\n        598,\n        563,\n        181,\n        1079,\n        235,\n        823,\n        28,\n        614,\n        469,\n        339,\n        627,\n        1033,\n        638,\n        553,\n        551,\n        1,\n        1040,\n        424,\n        365,\n        832,\n        496,\n        423,\n        516,\n        963,\n        1019,\n        567,\n        583,\n        373,\n        890,\n        492,\n        57,\n        972,\n        436,\n        210,\n        574,\n        796,\n        531,\n        132,\n        828\n    ],\n    \"train\": [\n        391,\n        479,\n        51,\n        1077,\n        882,\n        625,\n        268,\n        197,\n        433,\n        635,\n        246,\n        685,\n        640,\n        803,\n        131,\n        398,\n        1070,\n        580,\n        924,\n        863,\n        962,\n        238,\n        39,\n        763,\n        653,\n        455,\n        608,\n        869,\n        680,\n        768,\n        634,\n        107,\n        514,\n        509,\n        884,\n        417,\n        219,\n        146,\n        311,\n        630,\n        996,\n        290,\n        80,\n        911,\n        366,\n        353,\n        130,\n        814,\n        1001,\n        765,\n        100,\n        565,\n        487,\n        498,\n        158,\n        147,\n        862,\n        887,\n        13,\n        
529,\n        295,\n        444,\n        916,\n        795,\n        722,\n        260,\n        1016,\n        950,\n        609,\n        284,\n        895,\n        542,\n        535,\n        381,\n        450,\n        294,\n        750,\n        502,\n        494,\n        326,\n        556,\n        67,\n        1057,\n        384,\n        933,\n        143,\n        644,\n        362,\n        1084,\n        120,\n        16,\n        1094,\n        706,\n        788,\n        233,\n        905,\n        1036,\n        596,\n        733,\n        853,\n        418,\n        254,\n        119,\n        26,\n        559,\n        575,\n        549,\n        1088,\n        834,\n        966,\n        202,\n        434,\n        1095,\n        95,\n        308,\n        507,\n        590,\n        154,\n        988,\n        357,\n        1008,\n        425,\n        360,\n        234,\n        778,\n        819,\n        883,\n        90,\n        793,\n        611,\n        1028,\n        901,\n        812,\n        740,\n        432,\n        923,\n        68,\n        289,\n        167,\n        679,\n        784,\n        1093,\n        186,\n        671,\n        891,\n        849,\n        316,\n        729,\n        669,\n        213,\n        375,\n        801,\n        304,\n        991,\n        569,\n        9,\n        633,\n        603,\n        428,\n        190,\n        940,\n        642,\n        368,\n        771,\n        857,\n        282,\n        263,\n        419,\n        330,\n        1031,\n        420,\n        655,\n        876,\n        520,\n        1041,\n        544,\n        194,\n        791,\n        81,\n        122,\n        859,\n        3,\n        656,\n        1042,\n        201,\n        956,\n        753,\n        266,\n        262,\n        782,\n        243,\n        334,\n        54,\n        99,\n        341,\n        140,\n        280,\n        1027,\n        421,\n        787,\n        43,\n        401,\n      
  96,\n        115,\n        989,\n        1071,\n        973,\n        56,\n        734,\n        1061,\n        617,\n        850,\n        30,\n        928,\n        534,\n        573,\n        1067,\n        968,\n        949,\n        499,\n        343,\n        868,\n        990,\n        837,\n        242,\n        1092,\n        482,\n        1083,\n        992,\n        568,\n        350,\n        723,\n        969,\n        612,\n        467,\n        344,\n        810,\n        503,\n        1087,\n        693,\n        807,\n        619,\n        174,\n        717,\n        1076,\n        594,\n        29,\n        624,\n        586,\n        718,\n        336,\n        809,\n        926,\n        616,\n        441,\n        409,\n        1045,\n        1053,\n        954,\n        134,\n        451,\n        501,\n        537,\n        731,\n        533,\n        1010,\n        1006,\n        491,\n        708,\n        206,\n        688,\n        125,\n        927,\n        113,\n        770,\n        459,\n        709,\n        453,\n        390,\n        200,\n        713,\n        402,\n        906,\n        152,\n        538,\n        25,\n        379,\n        654,\n        766,\n        188,\n        267,\n        493,\n        661,\n        137,\n        776,\n        121,\n        764,\n        687,\n        430,\n        978,\n        570,\n        800,\n        47,\n        82,\n        701,\n        489,\n        959,\n        396,\n        762,\n        1020,\n        35,\n        216,\n        939,\n        50,\n        1007,\n        1026,\n        23,\n        948,\n        702,\n        438,\n        248,\n        987,\n        878,\n        77,\n        155,\n        278,\n        745,\n        135,\n        178,\n        998,\n        309,\n        37,\n        676,\n        759,\n        775,\n        870,\n        79,\n        1048,\n        539,\n        712,\n        699,\n        148,\n        412,\n        613,\n        986,\n  
      385,\n        34,\n        446,\n        484,\n        915,\n        207,\n        774,\n        979,\n        199,\n        657,\n        378,\n        269,\n        184,\n        371,\n        271,\n        856,\n        845,\n        394,\n        532,\n        359,\n        997,\n        673,\n        476,\n        739,\n        478,\n        236,\n        821,\n        515,\n        372,\n        665,\n        865,\n        695,\n        831,\n        965,\n        285,\n        879,\n        892,\n        958,\n        331,\n        1000,\n        666,\n        1090,\n        889,\n        374,\n        1091,\n        974,\n        276,\n        156,\n        179,\n        585,\n        643,\n        226,\n        935,\n        0,\n        714,\n        781,\n        790,\n        510,\n        1049,\n        257,\n        223,\n        447,\n        293,\n        651,\n        530,\n        825,\n        69,\n        332,\n        370,\n        1052,\n        548,\n        1065,\n        525,\n        6,\n        382,\n        291,\n        474,\n        726,\n        413,\n        55,\n        415,\n        452,\n        104,\n        196,\n        292,\n        756,\n        519,\n        1078,\n        215,\n        647,\n        716,\n        913,\n        779,\n        437,\n        1013,\n        582,\n        866,\n        748,\n        829,\n        571,\n        705,\n        299,\n        527,\n        139,\n        698,\n        277,\n        75,\n        606,\n        88,\n        1066,\n        838,\n        427,\n        124,\n        648,\n        321,\n        15,\n        725,\n        296,\n        486,\n        270,\n        97,\n        2,\n        286,\n        383,\n        193,\n        127,\n        808,\n        151,\n        338,\n        854,\n        483,\n        1080,\n        1050,\n        307,\n        677,\n        61,\n        108,\n        636,\n        700,\n        318,\n        684,\n        546,\n        166,\n   
     909,\n        84,\n        696,\n        1035,\n        462,\n        602,\n        481,\n        195,\n        760,\n        860,\n        123,\n        473,\n        435,\n        227,\n        517,\n        358,\n        900,\n        337,\n        495,\n        715,\n        315,\n        397,\n        785,\n        142,\n        936,\n        737,\n        622,\n        769,\n        660,\n        1074,\n        168,\n        393,\n        356,\n        403,\n        1018,\n        53,\n        833,\n        855,\n        659,\n        159,\n        180,\n        386,\n        641,\n        126,\n        87,\n        475,\n        169,\n        1058,\n        599,\n        844,\n        902,\n        208,\n        543,\n        918,\n        377,\n        523,\n        73,\n        595,\n        224,\n        281,\n        1024,\n        422,\n        1085,\n        490,\n        506,\n        724,\n        244,\n        231,\n        1064,\n        512,\n        953,\n        541,\n        952,\n        7,\n        994,\n        239,\n        662,\n        176,\n        410,\n        407,\n        730,\n        914,\n        572,\n        749,\n        141,\n        363,\n        63,\n        19,\n        841,\n        903,\n        1043,\n        1069,\n        830,\n        1038,\n        405,\n        11,\n        1023,\n        597,\n        449,\n        670,\n        678,\n        562,\n        440,\n        457,\n        683,\n        839,\n        253,\n        49,\n        161,\n        550,\n        946,\n        306,\n        182,\n        820,\n        566,\n        323,\n        32,\n        558,\n        145,\n        211,\n        1097,\n        300,\n        1012,\n        109,\n        312,\n        938,\n        144,\n        153,\n        183,\n        1037,\n        826,\n        743,\n        652,\n        513,\n        943,\n        157,\n        674,\n        505,\n        367,\n        921,\n        692,\n        22,\n        
76,\n        847,\n        757,\n        930,\n        752,\n        804,\n        249,\n        20,\n        872,\n        250,\n        94,\n        792,\n        649,\n        1029,\n        874,\n        103,\n        342,\n        251,\n        310,\n        592,\n        682,\n        191,\n        799,\n        42,\n        668,\n        811,\n        232,\n        985,\n        658,\n        588,\n        92,\n        458,\n        91,\n        929,\n        738,\n        369,\n        252,\n        203,\n        314,\n        710,\n        554,\n        187,\n        265,\n        364,\n        480,\n        704,\n        632,\n        220,\n        256,\n        114,\n        466,\n        615,\n        324,\n        65,\n        64,\n        408,\n        320,\n        400,\n        460,\n        327,\n        610,\n        727,\n        10,\n        27,\n        908,\n        325,\n        667,\n        212,\n        102,\n        322,\n        488,\n        894,\n        741,\n        910,\n        555,\n        744,\n        445,\n        105,\n        842,\n        1056,\n        697,\n        919,\n        1009,\n        118,\n        165,\n        899,\n        600,\n        245,\n        897,\n        754,\n        843,\n        971,\n        947,\n        932,\n        686,\n        628,\n        1021,\n        735,\n        46,\n        110,\n        283,\n        1081,\n        893,\n        786,\n        577,\n        302,\n        993,\n        273,\n        83,\n        579,\n        229,\n        951,\n        584,\n        846,\n        824,\n        601,\n        629,\n        117,\n        349,\n        851,\n        150,\n        389,\n        74,\n        416,\n        40,\n        858,\n        813,\n        945,\n        1062,\n        797,\n        500,\n        732,\n        618,\n        783,\n        298,\n        904,\n        1054,\n        346,\n        376,\n        917,\n        995,\n        957,\n        340,\n        
274,\n        794,\n        964,\n        170,\n        173,\n        136,\n        86,\n        41,\n        742,\n        66,\n        240,\n        1082,\n        970,\n        547,\n        703,\n        922,\n        560,\n        1051,\n        98,\n        818,\n        272,\n        218,\n        439,\n        347,\n        138,\n        576,\n        1047,\n        955,\n        160,\n        885,\n        288,\n        411,\n        626,\n        333,\n        934,\n        511,\n        981,\n        303,\n        399,\n        1055,\n        106,\n        504,\n        198,\n        605,\n        1073,\n        690,\n        587,\n        1044,\n        101,\n        355,\n        205,\n        387,\n        835,\n        521,\n        637,\n        720,\n        1086,\n        175,\n        471,\n        976,\n        222,\n        604,\n        38,\n        984,\n        8,\n        133,\n        258,\n        578,\n        426,\n        162,\n        761,\n        873,\n        317,\n        78,\n        937,\n        313,\n        48,\n        217,\n        128,\n        305,\n        755,\n        1030,\n        875,\n        646,\n        1003,\n        328,\n        822,\n        589,\n        691,\n        404,\n        31,\n        664,\n        536,\n        228,\n        461,\n        528,\n        711,\n        925,\n        645,\n        225,\n        1060,\n        557,\n        982,\n        694,\n        518,\n        721\n    ]\n}"
  },
  {
    "path": "datainfo/circular_motion/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        24,\n        688,\n        255,\n        176,\n        368,\n        371,\n        58,\n        32,\n        777,\n        734,\n        919,\n        433,\n        61,\n        901,\n        966,\n        215,\n        844,\n        250,\n        272,\n        874,\n        139,\n        534,\n        772,\n        108,\n        937,\n        896,\n        698,\n        232,\n        605,\n        1066,\n        50,\n        668,\n        588,\n        60,\n        217,\n        391,\n        17,\n        699,\n        154,\n        750,\n        1001,\n        425,\n        638,\n        832,\n        1032,\n        621,\n        633,\n        982,\n        1082,\n        340,\n        664,\n        454,\n        996,\n        935,\n        718,\n        1062,\n        931,\n        1057,\n        1025,\n        1020,\n        571,\n        1003,\n        511,\n        944,\n        818,\n        330,\n        1060,\n        741,\n        724,\n        1079,\n        849,\n        912,\n        372,\n        1052,\n        736,\n        1065,\n        1044,\n        279,\n        355,\n        665,\n        361,\n        48,\n        472,\n        483,\n        363,\n        336,\n        867,\n        778,\n        652,\n        952,\n        745,\n        56,\n        1090,\n        549,\n        1028,\n        911,\n        761,\n        1042,\n        805,\n        882,\n        324,\n        73,\n        434,\n        515,\n        631,\n        346,\n        739,\n        173,\n        187,\n        115\n    ],\n    \"val\": [\n        810,\n        639,\n        233,\n        1017,\n        82,\n        256,\n        75,\n        455,\n        731,\n        238,\n        252,\n        1041,\n        725,\n        520,\n        237,\n        650,\n        464,\n        98,\n        174,\n        165,\n        134,\n        34,\n        259,\n        995,\n        686,\n        143,\n        1049,\n        716,\n   
     18,\n        1048,\n        429,\n        620,\n        268,\n        265,\n        349,\n        924,\n        147,\n        335,\n        527,\n        498,\n        402,\n        799,\n        598,\n        530,\n        986,\n        895,\n        807,\n        458,\n        989,\n        104,\n        858,\n        323,\n        703,\n        676,\n        95,\n        230,\n        484,\n        157,\n        722,\n        636,\n        883,\n        411,\n        773,\n        270,\n        923,\n        46,\n        757,\n        619,\n        784,\n        564,\n        459,\n        315,\n        31,\n        500,\n        345,\n        292,\n        1098,\n        765,\n        760,\n        642,\n        630,\n        352,\n        4,\n        37,\n        155,\n        253,\n        813,\n        44,\n        603,\n        394,\n        1,\n        708,\n        535,\n        188,\n        752,\n        160,\n        958,\n        1033,\n        130,\n        261,\n        382,\n        21,\n        940,\n        746,\n        41,\n        25,\n        69,\n        977,\n        117,\n        84\n    ],\n    \"train\": [\n        521,\n        97,\n        212,\n        436,\n        465,\n        835,\n        410,\n        195,\n        591,\n        288,\n        988,\n        828,\n        450,\n        127,\n        42,\n        93,\n        92,\n        970,\n        873,\n        239,\n        424,\n        1086,\n        956,\n        310,\n        266,\n        824,\n        142,\n        717,\n        119,\n        727,\n        1008,\n        357,\n        871,\n        584,\n        64,\n        396,\n        817,\n        406,\n        57,\n        2,\n        681,\n        9,\n        249,\n        864,\n        512,\n        198,\n        245,\n        40,\n        473,\n        1089,\n        299,\n        182,\n        601,\n        379,\n        321,\n        955,\n        763,\n        486,\n        766,\n        22,\n        
906,\n        744,\n        171,\n        316,\n        702,\n        467,\n        370,\n        898,\n        274,\n        820,\n        1015,\n        644,\n        443,\n        707,\n        13,\n        576,\n        764,\n        109,\n        838,\n        559,\n        158,\n        26,\n        670,\n        721,\n        329,\n        789,\n        798,\n        802,\n        767,\n        133,\n        431,\n        539,\n        325,\n        672,\n        485,\n        618,\n        140,\n        135,\n        438,\n        692,\n        408,\n        608,\n        997,\n        640,\n        29,\n        344,\n        572,\n        421,\n        748,\n        6,\n        327,\n        1039,\n        235,\n        308,\n        796,\n        635,\n        290,\n        581,\n        713,\n        178,\n        612,\n        943,\n        568,\n        866,\n        207,\n        546,\n        519,\n        917,\n        889,\n        214,\n        339,\n        111,\n        448,\n        35,\n        63,\n        248,\n        282,\n        405,\n        163,\n        629,\n        782,\n        226,\n        659,\n        850,\n        902,\n        487,\n        1058,\n        592,\n        540,\n        596,\n        505,\n        831,\n        801,\n        575,\n        1046,\n        800,\n        833,\n        963,\n        556,\n        758,\n        504,\n        347,\n        236,\n        206,\n        441,\n        682,\n        1085,\n        547,\n        1040,\n        507,\n        87,\n        1094,\n        580,\n        975,\n        604,\n        648,\n        569,\n        294,\n        7,\n        916,\n        607,\n        788,\n        228,\n        281,\n        573,\n        376,\n        219,\n        915,\n        241,\n        167,\n        208,\n        1080,\n        216,\n        529,\n        1075,\n        94,\n        690,\n        907,\n        542,\n        578,\n        146,\n        120,\n        968,\n        
49,\n        862,\n        211,\n        680,\n        447,\n        815,\n        586,\n        1068,\n        488,\n        693,\n        306,\n        271,\n        269,\n        314,\n        43,\n        792,\n        797,\n        976,\n        417,\n        954,\n        191,\n        499,\n        151,\n        729,\n        854,\n        848,\n        552,\n        781,\n        307,\n        407,\n        826,\n        881,\n        985,\n        1097,\n        1069,\n        397,\n        423,\n        508,\n        168,\n        76,\n        998,\n        627,\n        439,\n        860,\n        258,\n        229,\n        105,\n        645,\n        590,\n        751,\n        377,\n        759,\n        791,\n        597,\n        80,\n        106,\n        887,\n        891,\n        386,\n        753,\n        1054,\n        322,\n        100,\n        809,\n        965,\n        300,\n        843,\n        661,\n        913,\n        899,\n        617,\n        283,\n        855,\n        567,\n        99,\n        842,\n        926,\n        737,\n        574,\n        634,\n        967,\n        440,\n        987,\n        948,\n        637,\n        960,\n        886,\n        416,\n        562,\n        51,\n        513,\n        1019,\n        662,\n        732,\n        803,\n        66,\n        356,\n        460,\n        419,\n        897,\n        152,\n        932,\n        193,\n        359,\n        945,\n        606,\n        476,\n        583,\n        660,\n        624,\n        786,\n        369,\n        466,\n        506,\n        110,\n        557,\n        403,\n        1012,\n        351,\n        378,\n        1091,\n        981,\n        276,\n        415,\n        0,\n        251,\n        427,\n        953,\n        432,\n        942,\n        770,\n        683,\n        616,\n        566,\n        877,\n        787,\n        677,\n        456,\n        257,\n        70,\n        15,\n        286,\n        71,\n        
929,\n        33,\n        123,\n        689,\n        756,\n        518,\n        332,\n        827,\n        47,\n        780,\n        614,\n        790,\n        12,\n        876,\n        112,\n        632,\n        541,\n        671,\n        129,\n        845,\n        694,\n        234,\n        156,\n        201,\n        482,\n        490,\n        884,\n        348,\n        1022,\n        1021,\n        302,\n        380,\n        921,\n        738,\n        231,\n        1084,\n        197,\n        210,\n        723,\n        1016,\n        354,\n        205,\n        964,\n        1038,\n        517,\n        755,\n        593,\n        1064,\n        305,\n        918,\n        278,\n        691,\n        1005,\n        1006,\n        242,\n        675,\n        200,\n        27,\n        116,\n        492,\n        430,\n        1073,\n        1030,\n        949,\n        687,\n        59,\n        295,\n        267,\n        920,\n        613,\n        1009,\n        398,\n        227,\n        880,\n        890,\n        90,\n        1081,\n        951,\n        785,\n        865,\n        331,\n        190,\n        172,\n        23,\n        312,\n        1061,\n        122,\n        362,\n        126,\n        646,\n        175,\n        45,\n        442,\n        678,\n        563,\n        446,\n        128,\n        169,\n        857,\n        775,\n        719,\n        1070,\n        991,\n        806,\n        623,\n        754,\n        77,\n        615,\n        684,\n        145,\n        463,\n        199,\n        304,\n        714,\n        666,\n        816,\n        863,\n        277,\n        868,\n        223,\n        333,\n        841,\n        878,\n        532,\n        808,\n        469,\n        39,\n        183,\n        85,\n        179,\n        53,\n        14,\n        65,\n        861,\n        704,\n        565,\n        1010,\n        457,\n        973,\n        244,\n        225,\n        647,\n        461,\n     
   132,\n        859,\n        247,\n        202,\n        936,\n        972,\n        962,\n        343,\n        544,\n        318,\n        1007,\n        149,\n        747,\n        959,\n        611,\n        385,\n        930,\n        479,\n        360,\n        445,\n        365,\n        384,\n        1092,\n        328,\n        287,\n        350,\n        409,\n        387,\n        743,\n        555,\n        1083,\n        1011,\n        728,\n        189,\n        776,\n        194,\n        480,\n        243,\n        101,\n        162,\n        334,\n        1018,\n        643,\n        649,\n        137,\n        847,\n        298,\n        452,\n        749,\n        170,\n        994,\n        554,\n        730,\n        79,\n        1043,\n        326,\n        218,\n        373,\n        388,\n        1037,\n        1045,\n        62,\n        284,\n        971,\n        705,\n        836,\n        888,\n        983,\n        622,\n        91,\n        516,\n        153,\n        543,\n        957,\n        837,\n        892,\n        933,\n        922,\n        673,\n        319,\n        341,\n        311,\n        651,\n        264,\n        783,\n        309,\n        551,\n        742,\n        136,\n        510,\n        879,\n        946,\n        453,\n        577,\n        654,\n        550,\n        900,\n        814,\n        969,\n        641,\n        570,\n        289,\n        10,\n        1095,\n        301,\n        653,\n        390,\n        1051,\n        166,\n        55,\n        138,\n        978,\n        908,\n        1034,\n        1067,\n        81,\n        28,\n        471,\n        514,\n        273,\n        470,\n        762,\n        496,\n        829,\n        67,\n        658,\n        795,\n        3,\n        358,\n        74,\n        209,\n        478,\n        177,\n        579,\n        525,\n        667,\n        493,\n        834,\n        8,\n        141,\n        626,\n        501,\n        779,\n    
    594,\n        600,\n        669,\n        204,\n        905,\n        812,\n        856,\n        220,\n        395,\n        910,\n        582,\n        221,\n        655,\n        1047,\n        823,\n        181,\n        822,\n        1026,\n        846,\n        320,\n        11,\n        894,\n        1024,\n        412,\n        192,\n        938,\n        545,\n        934,\n        83,\n        338,\n        560,\n        875,\n        1072,\n        710,\n        925,\n        1002,\n        414,\n        280,\n        494,\n        54,\n        522,\n        999,\n        184,\n        392,\n        609,\n        30,\n        451,\n        477,\n        148,\n        587,\n        367,\n        413,\n        240,\n        1071,\n        161,\n        1096,\n        263,\n        526,\n        254,\n        528,\n        1074,\n        720,\n        872,\n        553,\n        159,\n        114,\n        293,\n        536,\n        992,\n        449,\n        337,\n        711,\n        903,\n        774,\n        399,\n        735,\n        131,\n        38,\n        663,\n        697,\n        88,\n        68,\n        426,\n        150,\n        558,\n        180,\n        364,\n        72,\n        1014,\n        589,\n        599,\n        86,\n        715,\n        695,\n        712,\n        437,\n        561,\n        1055,\n        610,\n        1078,\n        974,\n        852,\n        769,\n        303,\n        696,\n        503,\n        909,\n        1063,\n        489,\n        870,\n        495,\n        275,\n        383,\n        144,\n        468,\n        366,\n        297,\n        509,\n        1013,\n        947,\n        1050,\n        481,\n        674,\n        1035,\n        1059,\n        1087,\n        709,\n        840,\n        400,\n        656,\n        342,\n        1093,\n        939,\n        125,\n        853,\n        260,\n        124,\n        404,\n        353,\n        771,\n        904,\n        794,\n       
 733,\n        401,\n        979,\n        585,\n        118,\n        196,\n        502,\n        941,\n        224,\n        78,\n        740,\n        804,\n        811,\n        927,\n        291,\n        5,\n        885,\n        628,\n        602,\n        213,\n        474,\n        497,\n        1077,\n        420,\n        1029,\n        462,\n        16,\n        990,\n        793,\n        203,\n        313,\n        851,\n        103,\n        422,\n        839,\n        1053,\n        950,\n        381,\n        706,\n        296,\n        375,\n        625,\n        121,\n        531,\n        19,\n        374,\n        491,\n        821,\n        679,\n        96,\n        185,\n        830,\n        537,\n        428,\n        52,\n        1088,\n        595,\n        980,\n        523,\n        435,\n        444,\n        825,\n        1036,\n        1004,\n        1076,\n        701,\n        1023,\n        389,\n        657,\n        548,\n        317,\n        914,\n        475,\n        685,\n        533,\n        984,\n        222,\n        107,\n        893,\n        869,\n        186,\n        20,\n        102,\n        928,\n        246,\n        89,\n        1056,\n        524,\n        113,\n        164,\n        418,\n        393,\n        36,\n        1027,\n        961,\n        768,\n        538,\n        285,\n        1000,\n        700,\n        262,\n        993,\n        726,\n        1031,\n        819\n    ]\n}"
  },
  {
    "path": "datainfo/circular_motion/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        887,\n        659,\n        395,\n        532,\n        890,\n        471,\n        385,\n        386,\n        882,\n        918,\n        141,\n        981,\n        368,\n        321,\n        347,\n        888,\n        1003,\n        43,\n        706,\n        159,\n        269,\n        625,\n        298,\n        417,\n        994,\n        202,\n        971,\n        32,\n        548,\n        614,\n        110,\n        78,\n        7,\n        317,\n        1065,\n        483,\n        571,\n        677,\n        773,\n        92,\n        90,\n        243,\n        850,\n        1082,\n        601,\n        41,\n        1085,\n        840,\n        136,\n        704,\n        181,\n        990,\n        987,\n        129,\n        254,\n        583,\n        546,\n        432,\n        213,\n        668,\n        334,\n        572,\n        58,\n        689,\n        475,\n        834,\n        718,\n        790,\n        1038,\n        862,\n        1076,\n        893,\n        528,\n        444,\n        1013,\n        278,\n        73,\n        199,\n        748,\n        274,\n        910,\n        808,\n        874,\n        793,\n        968,\n        551,\n        63,\n        616,\n        87,\n        326,\n        131,\n        31,\n        798,\n        1071,\n        310,\n        474,\n        308,\n        813,\n        975,\n        963,\n        392,\n        479,\n        531,\n        960,\n        26,\n        134,\n        970,\n        757,\n        267,\n        487\n    ],\n    \"val\": [\n        223,\n        897,\n        759,\n        344,\n        81,\n        173,\n        876,\n        829,\n        448,\n        229,\n        853,\n        691,\n        341,\n        329,\n        930,\n        104,\n        99,\n        714,\n        665,\n        445,\n        694,\n        191,\n        335,\n        244,\n        275,\n        823,\n        669,\n        227,\n        512,\n   
     1022,\n        1094,\n        752,\n        700,\n        582,\n        27,\n        832,\n        1056,\n        107,\n        996,\n        806,\n        307,\n        270,\n        988,\n        864,\n        936,\n        776,\n        320,\n        189,\n        372,\n        1039,\n        327,\n        615,\n        606,\n        305,\n        467,\n        643,\n        257,\n        378,\n        21,\n        692,\n        62,\n        974,\n        22,\n        501,\n        755,\n        285,\n        723,\n        623,\n        950,\n        939,\n        695,\n        361,\n        477,\n        340,\n        642,\n        648,\n        61,\n        1037,\n        647,\n        603,\n        630,\n        995,\n        20,\n        322,\n        593,\n        425,\n        807,\n        11,\n        1005,\n        561,\n        1083,\n        533,\n        264,\n        447,\n        1035,\n        958,\n        1030,\n        732,\n        737,\n        649,\n        441,\n        277,\n        519,\n        830,\n        1088,\n        635,\n        105,\n        1050,\n        697,\n        609\n    ],\n    \"train\": [\n        401,\n        148,\n        464,\n        711,\n        957,\n        1084,\n        420,\n        1007,\n        318,\n        1074,\n        579,\n        972,\n        434,\n        2,\n        770,\n        220,\n        1010,\n        540,\n        192,\n        740,\n        857,\n        788,\n        545,\n        408,\n        352,\n        163,\n        469,\n        118,\n        758,\n        534,\n        506,\n        587,\n        504,\n        396,\n        863,\n        760,\n        231,\n        794,\n        1059,\n        1049,\n        768,\n        655,\n        456,\n        279,\n        313,\n        459,\n        769,\n        371,\n        80,\n        747,\n        66,\n        67,\n        525,\n        804,\n        521,\n        144,\n        1081,\n        894,\n        1060,\n        
1066,\n        682,\n        439,\n        476,\n        142,\n        576,\n        620,\n        38,\n        672,\n        909,\n        345,\n        328,\n        1019,\n        193,\n        1087,\n        713,\n        179,\n        413,\n        156,\n        607,\n        640,\n        898,\n        1046,\n        511,\n        1027,\n        253,\n        558,\n        135,\n        781,\n        1098,\n        599,\n        945,\n        1073,\n        157,\n        905,\n        221,\n        197,\n        177,\n        59,\n        219,\n        154,\n        178,\n        486,\n        421,\n        727,\n        64,\n        53,\n        402,\n        1097,\n        355,\n        751,\n        228,\n        646,\n        631,\n        852,\n        452,\n        627,\n        1052,\n        309,\n        292,\n        462,\n        325,\n        784,\n        555,\n        83,\n        225,\n        149,\n        350,\n        1025,\n        1031,\n        468,\n        427,\n        969,\n        49,\n        846,\n        836,\n        262,\n        473,\n        1012,\n        520,\n        12,\n        573,\n        60,\n        212,\n        1048,\n        973,\n        820,\n        671,\n        374,\n        460,\n        249,\n        859,\n        171,\n        172,\n        198,\n        218,\n        407,\n        746,\n        908,\n        502,\n        112,\n        1036,\n        450,\n        594,\n        1089,\n        999,\n        931,\n        809,\n        301,\n        137,\n        236,\n        819,\n        639,\n        753,\n        843,\n        1002,\n        1069,\n        75,\n        207,\n        100,\n        214,\n        300,\n        346,\n        720,\n        8,\n        578,\n        89,\n        188,\n        272,\n        1070,\n        418,\n        84,\n        725,\n        167,\n        238,\n        875,\n        800,\n        470,\n        390,\n        1055,\n        835,\n        211,\n        
312,\n        217,\n        656,\n        0,\n        405,\n        151,\n        941,\n        1026,\n        174,\n        375,\n        586,\n        304,\n        814,\n        315,\n        265,\n        855,\n        1072,\n        929,\n        143,\n        565,\n        778,\n        485,\n        779,\n        337,\n        201,\n        562,\n        399,\n        186,\n        709,\n        500,\n        608,\n        1080,\n        362,\n        29,\n        743,\n        621,\n        42,\n        48,\n        121,\n        949,\n        1040,\n        696,\n        266,\n        557,\n        685,\n        6,\n        690,\n        624,\n        10,\n        1044,\n        955,\n        952,\n        719,\n        424,\n        16,\n        680,\n        162,\n        415,\n        287,\n        575,\n        821,\n        411,\n        363,\n        780,\n        503,\n        754,\n        71,\n        14,\n        86,\n        935,\n        702,\n        394,\n        299,\n        717,\n        721,\n        921,\n        589,\n        977,\n        693,\n        658,\n        889,\n        916,\n        496,\n        114,\n        296,\n        185,\n        938,\n        842,\n        712,\n        792,\n        772,\n        19,\n        233,\n        454,\n        224,\n        1018,\n        896,\n        789,\n        966,\n        155,\n        698,\n        106,\n        216,\n        687,\n        872,\n        242,\n        449,\n        404,\n        82,\n        535,\n        480,\n        209,\n        998,\n        482,\n        629,\n        248,\n        113,\n        1041,\n        705,\n        1016,\n        673,\n        775,\n        40,\n        17,\n        370,\n        1034,\n        68,\n        841,\n        168,\n        165,\n        507,\n        9,\n        801,\n        762,\n        39,\n        259,\n        364,\n        316,\n        633,\n        803,\n        25,\n        319,\n        666,\n        239,\n    
    708,\n        914,\n        728,\n        145,\n        1057,\n        412,\n        815,\n        472,\n        518,\n        699,\n        108,\n        619,\n        397,\n        735,\n        564,\n        652,\n        684,\n        745,\n        251,\n        119,\n        338,\n        951,\n        943,\n        379,\n        97,\n        797,\n        387,\n        559,\n        1061,\n        466,\n        458,\n        351,\n        844,\n        1032,\n        324,\n        158,\n        679,\n        282,\n        161,\n        443,\n        208,\n        828,\n        15,\n        899,\n        605,\n        526,\n        103,\n        670,\n        653,\n        1045,\n        543,\n        451,\n        517,\n        597,\n        724,\n        976,\n        303,\n        65,\n        516,\n        1000,\n        574,\n        343,\n        1047,\n        290,\n        96,\n        357,\n        674,\n        416,\n        654,\n        667,\n        288,\n        617,\n        1058,\n        1095,\n        513,\n        437,\n        77,\n        57,\n        764,\n        498,\n        457,\n        736,\n        414,\n        822,\n        993,\n        206,\n        170,\n        928,\n        263,\n        393,\n        356,\n        681,\n        831,\n        628,\n        774,\n        937,\n        816,\n        913,\n        196,\n        138,\n        956,\n        750,\n        660,\n        200,\n        923,\n        294,\n        1064,\n        283,\n        510,\n        1028,\n        366,\n        707,\n        810,\n        947,\n        650,\n        595,\n        610,\n        246,\n        234,\n        95,\n        331,\n        153,\n        284,\n        854,\n        932,\n        638,\n        1001,\n        783,\n        169,\n        481,\n        72,\n        286,\n        55,\n        1075,\n        817,\n        901,\n        911,\n        455,\n        926,\n        91,\n        76,\n        741,\n        373,\n 
       489,\n        756,\n        777,\n        380,\n        505,\n        927,\n        967,\n        596,\n        367,\n        536,\n        180,\n        560,\n        252,\n        891,\n        1051,\n        1009,\n        585,\n        289,\n        538,\n        552,\n        946,\n        1024,\n        917,\n        644,\n        730,\n        591,\n        877,\n        377,\n        686,\n        895,\n        256,\n        989,\n        115,\n        442,\n        1020,\n        281,\n        924,\n        676,\n        786,\n        400,\n        742,\n        215,\n        984,\n        497,\n        388,\n        446,\n        339,\n        342,\n        30,\n        824,\n        997,\n        463,\n        365,\n        839,\n        851,\n        920,\n        664,\n        120,\n        866,\n        93,\n        881,\n        922,\n        925,\n        641,\n        884,\n        261,\n        845,\n        900,\n        140,\n        1015,\n        194,\n        867,\n        260,\n        865,\n        749,\n        715,\n        235,\n        478,\n        811,\n        767,\n        74,\n        802,\n        688,\n        570,\n        1090,\n        210,\n        491,\n        622,\n        147,\n        1067,\n        892,\n        645,\n        731,\n        176,\n        205,\n        391,\n        490,\n        1068,\n        23,\n        314,\n        222,\n        598,\n        791,\n        302,\n        550,\n        237,\n        703,\n        744,\n        150,\n        965,\n        273,\n        139,\n        190,\n        733,\n        453,\n        612,\n        611,\n        509,\n        98,\n        166,\n        152,\n        410,\n        837,\n        495,\n        529,\n        983,\n        878,\n        964,\n        164,\n        422,\n        240,\n        907,\n        661,\n        85,\n        701,\n        569,\n        358,\n        406,\n        613,\n        376,\n        785,\n        403,\n        
959,\n        465,\n        255,\n        636,\n        247,\n        79,\n        184,\n        954,\n        332,\n        563,\n        782,\n        566,\n        102,\n        567,\n        109,\n        796,\n        541,\n        523,\n        663,\n        37,\n        1029,\n        306,\n        592,\n        291,\n        493,\n        117,\n        1093,\n        683,\n        787,\n        24,\n        183,\n        384,\n        101,\n        675,\n        94,\n        761,\n        45,\n        886,\n        241,\n        146,\n        1096,\n        5,\n        1033,\n        985,\n        544,\n        577,\n        1078,\n        111,\n        1092,\n        871,\n        795,\n        726,\n        860,\n        124,\n        268,\n        982,\n        383,\n        333,\n        28,\n        133,\n        436,\n        1004,\n        435,\n        738,\n        604,\n        276,\n        1017,\n        879,\n        849,\n        18,\n        716,\n        428,\n        430,\n        739,\n        508,\n        953,\n        381,\n        915,\n        979,\n        833,\n        44,\n        651,\n        722,\n        554,\n        880,\n        618,\n        1063,\n        382,\n        1006,\n        182,\n        580,\n        883,\n        389,\n        3,\n        568,\n        130,\n        126,\n        438,\n        336,\n        870,\n        271,\n        369,\n        311,\n        600,\n        398,\n        662,\n        991,\n        359,\n        1023,\n        869,\n        160,\n        323,\n        942,\n        514,\n        527,\n        88,\n        729,\n        827,\n        494,\n        70,\n        766,\n        127,\n        47,\n        1042,\n        56,\n        1,\n        330,\n        484,\n        52,\n        433,\n        992,\n        539,\n        632,\n        903,\n        1043,\n        848,\n        51,\n        409,\n        584,\n        934,\n        499,\n        1021,\n        1079,\n        
280,\n        245,\n        799,\n        1086,\n        547,\n        125,\n        1014,\n        226,\n        360,\n        948,\n        1054,\n        553,\n        258,\n        128,\n        818,\n        116,\n        626,\n        54,\n        838,\n        549,\n        537,\n        961,\n        904,\n        902,\n        919,\n        515,\n        175,\n        873,\n        861,\n        492,\n        13,\n        50,\n        590,\n        440,\n        203,\n        524,\n        710,\n        35,\n        46,\n        250,\n        122,\n        293,\n        602,\n        69,\n        232,\n        349,\n        556,\n        295,\n        1091,\n        765,\n        678,\n        771,\n        933,\n        763,\n        33,\n        906,\n        734,\n        986,\n        1062,\n        522,\n        637,\n        488,\n        4,\n        204,\n        1008,\n        423,\n        36,\n        419,\n        581,\n        429,\n        297,\n        426,\n        978,\n        354,\n        1053,\n        885,\n        812,\n        530,\n        1011,\n        431,\n        132,\n        940,\n        353,\n        634,\n        825,\n        1077,\n        657,\n        847,\n        868,\n        348,\n        944,\n        187,\n        588,\n        858,\n        856,\n        826,\n        962,\n        195,\n        542,\n        34,\n        123,\n        805,\n        230,\n        980,\n        461,\n        912\n    ]\n}"
  },
  {
    "path": "datainfo/collect/circular_motion/make_data.py",
    "content": "import numpy as np\nimport os\nfrom tqdm import tqdm\nimport shutil\nfrom PIL import Image, ImageDraw\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\n# background and object properties\nimg_size = (128, 128)\nbg_color = 'gray'\nobj_size = 8    # half-width of the object in pixels\nobj_shape = 'ball'\nobj_color = 'blue'\n\n\ndef coord2img(x, y):\n    # render the object centered at normalized coordinates (x, y) in [0, 1],\n    # with the y axis pointing up\n    img = Image.new('RGB', img_size, bg_color)\n    draw = ImageDraw.Draw(img)\n\n    pos_x = int(x * img_size[0])\n    pos_y = int((1-y) * img_size[1])\n    pos = (pos_x-obj_size, pos_y-obj_size, pos_x+obj_size, pos_y+obj_size)\n\n    if obj_shape == 'square':\n        draw.rectangle(pos, fill=obj_color)\n    elif obj_shape == 'ball':\n        draw.ellipse(pos, fill=obj_color)\n    else:\n        raise ValueError('unknown obj_shape: ' + obj_shape)\n\n    return img\n\n\ndata_filepath = './circular_motion'\nnum_exp = 1100\nnum_frames = 60\n\n# alternative setting: a single long sequence\n# data_filepath = './circular_motion_long'\n# num_exp = 1\n# num_frames = 60*20\n\ncenter = (0.5, 0.5)\nradius = 0.3\nnp.random.seed(0)\n\nfor p_exp in tqdm(range(num_exp)):\n    # random initial angle and a random angular step per frame (rad),\n    # clockwise or counterclockwise\n    theta = np.random.uniform(0, 2*np.pi)\n    theta_d = np.random.choice((-1, 1)) * np.random.uniform(np.pi/30, np.pi/10)\n\n    seq_filepath = os.path.join(data_filepath, str(p_exp))\n    mkdir(seq_filepath)\n\n    for p_frame in range(num_frames):\n        x = center[0] + radius * np.cos(theta)\n        y = center[1] + radius * np.sin(theta)\n        img = coord2img(x, y)\n        theta += theta_d\n        img.save(os.path.join(seq_filepath, str(p_frame)+'.png'))\n"
  },
  {
    "path": "datainfo/collect/double_pendulum/README.md",
    "content": "## Double Pendulum Data Collection\n\nPlace the raw videos of the double pendulum motion in ```./visphysics_data/double_pendulum_raw_videos``` under the current folder, then run the following commands.\nThe collected data are saved in ```./visphysics_data/double_pendulum```.\n\n1. python convert_video.py -f 60 (convert each video into a sequence of images sampled at 60 fps)\n2. python equalize_background.py -r 215 -g 205 -b 192 (replace the background with a uniform color; if -r/-g/-b are omitted, the color is estimated from the data)\n3. python split_data.py (split each sequence into subsequences of fixed length 60)"
  },
  {
    "path": "datainfo/collect/double_pendulum/convert_video.py",
    "content": "import cv2\nimport os\nimport shutil\nimport argparse\nimport random\nimport time\n\n## This script extracts frames from a set of input videos\n## at a desired sampling rate (fps).\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef get_arguments():\n    ap = argparse.ArgumentParser()\n    ap.add_argument('-f', '--fps', required=True, type=int,\n                    help='frames per second at which to sample the output frames')\n    ap.add_argument('-s', '--seed', default=0, type=int,\n                    help='random seed for shuffling')\n\n    args = ap.parse_args()\n    return args\n\n\ndef write_frames(video_path, frame_path, fps, seed):\n    # get list of all videos in video_path\n    listing = os.listdir(video_path)\n    # remove .DS_Store files\n    if '.DS_Store' in listing:\n        listing.remove('.DS_Store')\n    # random shuffle\n    random.seed(seed)\n    random.shuffle(listing)\n\n    num_videos = len(listing)\n    for (idx, filename) in enumerate(listing):\n        print(idx + 1, 'out of', num_videos)\n        cap = cv2.VideoCapture(os.path.join(video_path, filename))\n        video_fps = cap.get(cv2.CAP_PROP_FPS)\n        print('video fps:', video_fps)\n\n        data_path = os.path.join(frame_path, str(idx))\n        mkdir(data_path)\n\n        p_frame = 0  # frame count\n        # keep every skip-th frame; at least 1 so the modulo below never divides by zero\n        skip = max(1, round(video_fps / fps))\n        while True:\n            # read current frame\n            ret, frame = cap.read()\n            if not ret:\n                break\n\n            if p_frame % skip == 0:\n                # manual square crop (620x620 pixels)\n                frame = frame[0:620, 371:991]\n                # downsize\n                frame = cv2.resize(frame, (128, 128), interpolation=cv2.INTER_AREA)\n                cv2.imwrite(os.path.join(data_path, str(p_frame // skip)+'.png'), frame)\n            p_frame += 1\n\n        cap.release()\n        time.sleep(1.0)  # short delay for cleaning purposes\n\n\nvideo_path = \"./visphysics_data/double_pendulum_raw_videos\"\nframe_path = \"./visphysics_data/double_pendulum_raw_data\"\nargs = get_arguments()\nwrite_frames(video_path, frame_path, args.fps, args.seed)"
  },
  {
    "path": "datainfo/collect/double_pendulum/equalize_background.py",
    "content": "import cv2\nimport os\nimport shutil\nfrom tqdm import tqdm\nimport argparse\nimport numpy as np\n\n'''\nThis script replaces the background of every frame with a uniform color.\nThe background color is either user-specified or computed by averaging\nthe background pixels of sampled frames from the data.\n'''\n\n# manual threshold: a pixel whose red channel (index 2 in OpenCV's BGR order)\n# exceeds this value is treated as background\nBG_THRESHOLD = 130\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef get_arguments():\n    ap = argparse.ArgumentParser()\n    ap.add_argument('-b', '--channel0', type=float, help='b channel')\n    ap.add_argument('-g', '--channel1', type=float, help='g channel')\n    ap.add_argument('-r', '--channel2', type=float, help='r channel')\n\n    args = ap.parse_args()\n    return args\n\n\nraw_filepath = \"./visphysics_data/double_pendulum_raw_data\"\ndata_filepath = \"./visphysics_data/double_pendulum_background_corrected_data\"\n\nnum_exps = len(os.listdir(raw_filepath))\nbg = np.zeros(3)\n\nargs = get_arguments()\nif args.channel0 is not None and args.channel1 is not None and args.channel2 is not None:\n    # user-specified background color (BGR order, as used by OpenCV)\n    bg[0] = args.channel0\n    bg[1] = args.channel1\n    bg[2] = args.channel2\nelse:\n    # uniform background by averaging sampled pixels\n    percent = .05   # proportion of frames sampled from the start of each sequence\n    cnt = 0\n\n    print('Computing uniform background...')\n    for p_exp in tqdm(range(num_exps)):\n        seq_filepath = os.path.join(raw_filepath, str(p_exp))\n        num_frames = len(os.listdir(seq_filepath))\n        num_samples = int(percent * num_frames)\n        for p_frame in range(num_samples):\n            # read each frame\n            frame_path = os.path.join(seq_filepath, str(p_frame)+'.png')\n            frame = cv2.imread(frame_path)\n            # accumulate the colors of all background pixels\n            mask = frame[:, :, 2] > BG_THRESHOLD\n            bg += frame[mask].sum(axis=0)\n            cnt += int(mask.sum())\n\n    # average over all sampled pixels\n    bg = bg / cnt\n    print('Uniform Background R: ', bg[2], ' G: ', bg[1], ' B: ', bg[0])\n\n# apply the background to all data\nprint('Applying the background...')\nfor p_exp in tqdm(range(num_exps)):\n    seq_filepath = os.path.join(raw_filepath, str(p_exp))\n    data_seq_filepath = os.path.join(data_filepath, str(p_exp))\n    mkdir(data_seq_filepath)\n    num_frames = len(os.listdir(seq_filepath))\n\n    for p_frame in range(num_frames):\n        frame_path = os.path.join(seq_filepath, str(p_frame)+'.png')\n        frame = cv2.imread(frame_path)\n        # overwrite all background pixels with the uniform color\n        frame[frame[:, :, 2] > BG_THRESHOLD] = bg\n        # write to data\n        cv2.imwrite(os.path.join(data_seq_filepath, str(p_frame)+'.png'), frame)"
  },
  {
    "path": "datainfo/collect/double_pendulum/split_data.py",
    "content": "import os\nimport shutil\nfrom tqdm import tqdm\nimport argparse\n\n'''\nThis script splits each sequence, which corresponds to a single experiment,\ninto subsequences of fixed length.\n'''\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef get_arguments():\n    ap = argparse.ArgumentParser()\n    ap.add_argument('-l', '--len', type=int, default=60, help='length of each subsequence')\n\n    args = ap.parse_args()\n    return args\n\n\nsrc_filepath = \"./visphysics_data/double_pendulum_background_corrected_data\"\ndst_filepath = \"./visphysics_data/double_pendulum\"\n\nargs = get_arguments()\n\nnum_exps = len(os.listdir(src_filepath))\nsubseq_cnt = 0\nframe_cnt = 0\nmkdir(os.path.join(dst_filepath, str(subseq_cnt)))\n\nfor p_exp in tqdm(range(num_exps)):\n    src_seq_filepath = os.path.join(src_filepath, str(p_exp))\n    num_frames = len(os.listdir(src_seq_filepath))\n    # drop the trailing frames that do not fill a complete subsequence\n    num_frames = (num_frames // args.len) * args.len\n\n    for p_frame in range(num_frames):\n        src_path = os.path.join(src_seq_filepath, str(p_frame)+'.png')\n        dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.png')\n        shutil.copyfile(src_path, dst_path)\n\n        frame_cnt += 1\n        # start a new subsequence folder once the current one is full\n        if frame_cnt == args.len:\n            subseq_cnt += 1\n            frame_cnt = 0\n            mkdir(os.path.join(dst_filepath, str(subseq_cnt)))\n\n# remove the last (empty) folder\nshutil.rmtree(os.path.join(dst_filepath, str(subseq_cnt)))"
  },
  {
    "path": "datainfo/collect/elastic_pendulum/make_data.py",
    "content": "import os\nimport shutil\nimport numpy as np\nfrom tqdm import tqdm\nfrom PIL import Image, ImageDraw\nfrom scipy import linalg\nfrom scipy.integrate import solve_ivp\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef engine(rng, num_frm, fps=60):\n    # physical parameters\n    L_0 = 0.205  # elastic pendulum unstretched length (m)\n    L = 0.179    # rigid pendulum rod length (m)\n    w = 0.038    # rigid pendulum rod width (m)\n    m = 0.110    # rigid pendulum mass (kg)\n    g = 9.81     # gravitational acceleration (m/s^2)\n    k = 40.0     # elastic constant (kg/s^2)\n\n    dt = 1.0 / fps\n    t_eval = np.arange(num_frm) * dt\n    \n    # solve equations of motion\n    # y = [theta_1, theta_2, z, vel_theta_1, vel_theta_2, vel_z]\n    def f(t, y):\n        La = L_0 + y[2]\n        Lb = 0.5 * L\n        I = (1. / 12.) * m * (w ** 2 + L ** 2)\n        Jc = L * np.cos(y[0] - y[1])\n        Js = L * np.sin(y[0] - y[1])\n        A = np.array([[La**2, 0.5*La*Jc, 0],\n                      [0.5*La*Jc, Lb**2+I/m, 0.5*Js],\n                      [0, 0.5*Js, 1]])\n        b = np.zeros(3)\n        b[0] = -0.5*La*Js*y[4]**2 - 2*La*y[3]*y[5] - g*La*np.sin(y[0])\n        b[1] = 0.5*La*Js*y[3]**2 - Jc*y[3]*y[5] - g*Lb*np.sin(y[1])\n        b[2] = La*y[3]**2 + 0.5*Jc*y[4]**2 + g*np.cos(y[0]) - (k/m)*y[2]\n        sol = linalg.solve(A, b)\n        return [y[3], y[4], y[5], sol[0], sol[1], sol[2]]\n\n    # run until the drawn pendulums are inside the image\n    rej = True\n    while rej:\n        initial_state = [rng.uniform(0, 2*np.pi), rng.uniform(0, 2*np.pi), rng.uniform(-0.04, 0.04), rng.uniform(-6, 6), rng.uniform(-10, 10), 0]\n        sol = solve_ivp(f, [t_eval[0], t_eval[-1]], initial_state, t_eval=t_eval, rtol=1e-6)\n        states = sol.y.T\n        lim_x = np.abs((L_0+states[:, 2])*np.sin(states[:, 0]) + L*np.sin(states[:, 1])) - (L_0 + L)\n        lim_y = np.abs((L_0+states[:, 
2])*np.cos(states[:, 0]) + L*np.cos(states[:, 1])) - (L_0 + L)\n        rej = (np.max(lim_x) > 0.16) | (np.max(lim_y) > 0.16) | (np.min(states[:, 2]) < -0.08)\n    \n    return states\n\n\ndef draw_rect(im, col, top_x, top_y, w, h, theta):\n    x1 = top_x - w * np.cos(theta) / 2\n    y1 = top_y + w * np.sin(theta) / 2\n    x2 = x1 + w * np.cos(theta)\n    y2 = y1 - w * np.sin(theta)\n    x3 = x2 + h * np.sin(theta)\n    y3 = y2 + h * np.cos(theta)\n    x4 = x3 - w * np.cos(theta)\n    y4 = y3 + w * np.sin(theta)\n    pts = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]\n\n    draw = ImageDraw.Draw(im)\n    draw.polygon(pts, fill=col)\n\n\ndef render(theta_1, theta_2, z):\n    bg_color = (215, 205, 192)\n    pd1_color = (63, 66, 85)\n    pd2_color = (17, 93, 234)\n\n    im = Image.new('RGB', (1600, 1600), bg_color)\n    center = (800, 800)\n\n    w1, w2, h1, h2 = 90, 90, 300, 250\n    L_0 = 0.205\n    h1 *= (1 + z / L_0)\n    pd1_end_x = center[0] + (h1-35) * np.sin(theta_1)\n    pd1_end_y = center[1] + (h1-35) * np.cos(theta_1)\n\n    # pd1 may hide pd2\n    draw_rect(im, pd2_color, pd1_end_x, pd1_end_y, w2, h2, theta_2)\n    draw_rect(im, pd1_color, center[0], center[1], w1, h1, theta_1)\n\n    im = im.resize((128, 128))\n    return im\n\n\ndef make_data(data_filepath, num_seq, num_frm, seed=0):\n    mkdir(data_filepath)\n    rng = np.random.default_rng(seed)\n    states = np.zeros((num_seq, num_frm, 6))\n\n    for n in tqdm(range(num_seq)):\n        seq_filepath = os.path.join(data_filepath, str(n))\n        mkdir(seq_filepath)\n        states[n, :, :] = engine(rng, num_frm)\n        for k in range(num_frm):\n            im = render(states[n, k, 0], states[n, k, 1], states[n, k, 2])\n            im.save(os.path.join(seq_filepath, str(k)+'.png'))\n\n    np.save(os.path.join(data_filepath, 'states.npy'), states)\n\n\nif __name__ == '__main__':\n    data_filepath = '/data/kuang/visphysics_data/elastic_pendulum'\n    make_data(data_filepath, num_seq=1200, num_frm=60)"
  },
  {
    "path": "datainfo/collect/fire/split_data.py",
    "content": "import os\nimport shutil\nfrom tqdm import tqdm\nimport argparse\n\n'''\nThis script splits each sequence, which corresponds to a single experiment,\ninto subsequences with fixed length.\n'''\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef get_arguments():\n    ap = argparse.ArgumentParser()\n    ap.add_argument('-l', '--len', type=int, default=60, help='length of each subsequence')\n\n    args = ap.parse_args()                \n    return args\n\n\nsrc_train_filepath =  \"./visphysics_data/fire/train\"\nsrc_test_filepath = \"./visphysics_data/fire/test\"\ndst_filepath = \"./visphysics_data/fire_60\"\n\nargs = get_arguments()\n\nnum_train_exps = len(os.listdir(src_train_filepath))\nnum_test_exps = len(os.listdir(src_test_filepath))\nsubseq_cnt = 0\nframe_cnt = 0\nmkdir(os.path.join(dst_filepath, str(subseq_cnt)))\n\n# train\nfor p_exp in tqdm(range(num_train_exps)):\n    src_seq_filepath = os.path.join(src_train_filepath, str(p_exp))\n    num_frames = len(os.listdir(src_seq_filepath))\n    # ignore frames at the end of the sequence\n    num_frames = int(num_frames/args.len) * args.len\n\n    for p_frame in range(num_frames):\n        src_path = os.path.join(src_seq_filepath, str(p_frame)+'.jpg')\n        dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.jpg')\n        shutil.copyfile(src_path, dst_path)\n\n        frame_cnt += 1\n        if frame_cnt == args.len:\n            subseq_cnt += 1\n            frame_cnt = 0\n            mkdir(os.path.join(dst_filepath, str(subseq_cnt)))\n\n# test\nfor p_exp in tqdm(range(num_test_exps)):\n    src_seq_filepath = os.path.join(src_test_filepath, str(p_exp))\n    num_frames = len(os.listdir(src_seq_filepath))\n    # ignore frames at the end of the sequence\n    num_frames = int(num_frames/args.len) * args.len\n\n    for p_frame in range(num_frames):\n        src_path = os.path.join(src_seq_filepath, 
str(p_frame)+'.png')\n        dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.png')\n        shutil.copyfile(src_path, dst_path)\n\n        frame_cnt += 1\n        if frame_cnt == args.len:\n            subseq_cnt += 1\n            frame_cnt = 0\n            mkdir(os.path.join(dst_filepath, str(subseq_cnt)))\n\n# remove the last empty folder\nshutil.rmtree(os.path.join(dst_filepath, str(subseq_cnt)))"
  },
  {
    "path": "datainfo/collect/lava_lamp/split_data.py",
    "content": "import os\nimport shutil\nfrom tqdm import tqdm\nimport argparse\n\n'''\nThis script splits each sequence, which corresponds to a single experiment,\ninto subsequences with fixed length.\n'''\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef get_arguments():\n    ap = argparse.ArgumentParser()\n    ap.add_argument('-l', '--len', type=int, default=60, help='length of each subsequence')\n\n    args = ap.parse_args()                \n    return args\n\n\nsrc_train_filepath =  \"./visphysics_data/lavalamp/raw_data_train\"\nsrc_test_filepath = \"./visphysics_data/lavalamp/raw_data_test\"\ndst_filepath = \"./visphysics_data/lavalamp_60\"\n\nargs = get_arguments()\n\nnum_train_exps = len(os.listdir(src_train_filepath))\nnum_test_exps = len(os.listdir(src_test_filepath))\nsubseq_cnt = 0\nframe_cnt = 0\nmkdir(os.path.join(dst_filepath, str(subseq_cnt)))\n\n# train\nsrc_seq_filepath = src_train_filepath\nnum_frames = len(os.listdir(src_seq_filepath))\n# ignore frames at the end of the sequence\nnum_frames = int(num_frames/args.len) * args.len\n\nfor p_frame in range(num_frames):\n    src_path = os.path.join(src_seq_filepath, str(p_frame)+'.jpg')\n    dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.jpg')\n    shutil.copyfile(src_path, dst_path)\n\n    frame_cnt += 1\n    if frame_cnt == args.len:\n        subseq_cnt += 1\n        frame_cnt = 0\n        mkdir(os.path.join(dst_filepath, str(subseq_cnt)))\n\n# test\nsrc_seq_filepath = src_test_filepath\nnum_frames = len(os.listdir(src_seq_filepath))\n# ignore frames at the end of the sequence\nnum_frames = int(num_frames/args.len) * args.len\n\nfor p_frame in range(num_frames):\n    src_path = os.path.join(src_seq_filepath, str(p_frame)+'.png')\n    dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.png')\n    shutil.copyfile(src_path, dst_path)\n\n    frame_cnt += 1\n    if frame_cnt == args.len:\n 
       subseq_cnt += 1\n        frame_cnt = 0\n        mkdir(os.path.join(dst_filepath, str(subseq_cnt)))\n\n# remove the last empty folder\nshutil.rmtree(os.path.join(dst_filepath, str(subseq_cnt)))"
  },
  {
    "path": "datainfo/collect/reaction_diffusion/make_data.py",
    "content": "import os\nimport shutil\nfrom tqdm import tqdm\nimport numpy as np\nimport scipy.io as sio\nfrom PIL import Image\nfrom matplotlib import cm\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef make_data(data_filepath, num_frm):\n    data = sio.loadmat('reaction_diffusion.mat')\n    mkdir(data_filepath)\n\n    num_seq = int(data['t'].size / num_frm)\n    u_min, u_max = -1, 1\n\n    for n in tqdm(range(num_seq)):\n        seq_filepath = os.path.join(data_filepath, str(n))\n        mkdir(seq_filepath)\n        for k in range(num_frm):\n            u = data['uf'][:, :, n*num_frm+k]\n            u = (u - u_min) / (u_max - u_min)\n            u = cm.viridis(u)[:,:,:3]\n            u = np.uint8(u*255)\n            im = Image.fromarray(u)\n            im.save(os.path.join(seq_filepath, str(k)+'.png'))\n\n\nif __name__ == '__main__':\n    data_filepath = 'reaction_diffusion'\n    make_data(data_filepath, 100)"
  },
  {
    "path": "datainfo/collect/reaction_diffusion/reaction_diffusion.m",
    "content": "clear all; close all; clc\n\n% lambda-omega reaction-diffusion system\n%  u_t = lam(A) u - ome(A) v + d1*(u_xx + u_yy)\n%  v_t = ome(A) u + lam(A) v + d2*(v_xx + v_yy)\n%\n%  A^2 = u^2 + v^2 and\n%  lam(A) = 1 - A^2\n%  ome(A) = -beta*A^2\n\n\nt=0:0.2:2000.2;\nd1=0.1; d2=0.1; beta=1.0;\nL=20; n=128; N=n*n;\nx2=linspace(-L/2,L/2,n+1); x=x2(1:n); y=x;\nkx=(2*pi/L)*[0:(n/2-1) -n/2:-1]; ky=kx;\n\n% INITIAL CONDITIONS\n\n[X,Y]=meshgrid(x,y);\n[KX,KY]=meshgrid(kx,ky);\nK2=KX.^2+KY.^2; K22=reshape(K2,N,1);\n\nm=1; % number of spirals\n\nf = exp(-.01*(X.^2+Y.^2)); % circular mask to get rid of boundaries\n\nu_ini=tanh(sqrt(X.^2+Y.^2)).*cos(m*angle(X+i*Y)-(sqrt(X.^2+Y.^2)));\nv_ini=tanh(sqrt(X.^2+Y.^2)).*sin(m*angle(X+i*Y)-(sqrt(X.^2+Y.^2)));\n% uf(:,:,1)=f.*u(:,:,1);\n% vf(:,:,1)=f.*v(:,:,1);\n\n% REACTION-DIFFUSION\nuvt=[reshape(fft2(u_ini),1,N) reshape(fft2(v_ini),1,N)].';\nuvt_rhs = reaction_diffusion_rhs(t(1),uvt,[],K22,d1,d2,beta,n,N);\n% du(:,:,1) = real(ifft2(reshape(uvt_rhs(1:N).',n,n)));\n% dv(:,:,1)= real(ifft2(reshape(uvt_rhs((N+1):(2*N)).',n,n)));\n\n[t,uvsol]=ode45('reaction_diffusion_rhs',t,uvt,[],K22,d1,d2,beta,n,N);\n\nu = zeros(length(x),length(y),length(t));\nv = zeros(length(x),length(y),length(t));\nuf = zeros(length(x), length(y), length(t));\nvf = zeros(length(x), length(y), length(t));\n\ndu = zeros(length(x),length(y),length(t));\ndv = zeros(length(x),length(y),length(t));\nduf = zeros(length(x),length(y),length(t));\ndvf = zeros(length(x),length(y),length(t));\n\nfor j=1:length(t)-1\nut=reshape((uvsol(j,1:N).'),n,n);\nvt=reshape((uvsol(j,(N+1):(2*N)).'),n,n);\nu(:,:,j+1)=real(ifft2(ut));\nv(:,:,j+1)=real(ifft2(vt));\n\nuvt_rhs = reaction_diffusion_rhs(t(j+1),uvsol(j,1:end).',[],K22,d1,d2,beta,n,N);\ndu(:,:,j+1)= real(ifft2(reshape(uvt_rhs(1:N).',n,n)));\ndv(:,:,j+1)= 
real(ifft2(reshape(uvt_rhs((N+1):(2*N)).',n,n)));\n\nuf(:,:,j+1)=f.*u(:,:,j+1);\nvf(:,:,j+1)=f.*v(:,:,j+1);\nduf(:,:,j+1)=f.*du(:,:,j+1);\ndvf(:,:,j+1)=f.*dv(:,:,j+1);\n\n% figure(1)\n% pcolor(x,y,v(:,:,j+1)); shading interp; colormap(hot); colorbar; drawnow; \nend\n\nt = t(3:end);\nuf = uf(:,:,3:end);\nvf = vf(:,:,3:end);\nduf = duf(:,:,3:end);\ndvf = dvf(:,:,3:end);\n\nsave('reaction_diffusion.mat','t','x','y','uf')\n\n%%\n% load reaction_diffusion_big\n% pcolor(x,y,u(:,:,end)); shading interp; colormap(hot)\n\n\n"
  },
  {
    "path": "datainfo/collect/reaction_diffusion/reaction_diffusion_rhs.m",
    "content": "function rhs=reaction_diffusion_rhs(t,uvt,dummy,K22,d1,d2,beta,n,N);\n\n% Calculate u and v terms\nut=reshape((uvt(1:N)),n,n);\nvt=reshape((uvt((N+1):(2*N))),n,n);\nu=real(ifft2(ut)); v=real(ifft2(vt));\n\n% Reaction Terms\nu3=u.^3; v3=v.^3; u2v=(u.^2).*v; uv2=u.*(v.^2);\nutrhs=reshape((fft2(u-u3-uv2+beta*u2v+beta*v3)),N,1);\nvtrhs=reshape((fft2(v-u2v-v3-beta*u3-beta*uv2)),N,1);\n\nrhs=[-d1*K22.*uvt(1:N)+utrhs\n     -d2*K22.*uvt(N+1:end)+vtrhs];\n\n"
  },
  {
    "path": "datainfo/collect/single_pendulum/make_data.py",
    "content": "import os\nimport shutil\nimport numpy as np\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom scipy.integrate import solve_ivp\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef engine(rng, num_frm, fps=60):\n    # parameters\n    m = 1.0\n    l = 0.5\n    g = 9.81\n    dt = 1.0 / fps\n    t_eval = np.arange(num_frm) * dt\n\n    # solve equations of motion\n    # y = [theta, vel_theta]\n    f = lambda t, y: [y[1], -3*g/(2*l) * np.sin(y[0])]\n    initial_state = [rng.uniform(0, 2*np.pi), rng.uniform(-10, 10)]\n    sol = solve_ivp(f, [t_eval[0], t_eval[-1]], initial_state, t_eval=t_eval, rtol=1e-6)\n    \n    states = sol.y.T\n    return states\n\n\ndef render(theta):\n    bg_color = (215, 205, 192)\n    im = Image.new('RGB', (800, 800), bg_color)\n    pendulum_im = Image.open('./pendulum.png')\n    im.paste(pendulum_im, (363, 400))\n    im = im.rotate(theta*180/np.pi, fillcolor=bg_color)\n    im = im.resize((128,128))\n    return im\n\n\ndef make_data(data_filepath, num_seq, num_frm, seed=0):\n    mkdir(data_filepath)\n    rng = np.random.default_rng(seed)\n    states = np.zeros((num_seq, num_frm, 2))\n\n    for n in tqdm(range(num_seq)):\n        seq_filepath = os.path.join(data_filepath, str(n))\n        mkdir(seq_filepath)\n        states[n, :, :] = engine(rng, num_frm)\n        for k in range(num_frm):\n            im = render(states[n, k, 0])\n            im.save(os.path.join(seq_filepath, str(k)+'.png'))\n\n    np.save(os.path.join(data_filepath, 'states.npy'), states)\n\n\nif __name__ == '__main__':\n    data_filepath = '/data/kuang/visphysics_data/single_pendulum'\n    make_data(data_filepath, num_seq=1200, num_frm=60)"
  },
  {
    "path": "datainfo/collect/utils/sort_vids.py",
    "content": "import os\nimport shutil\nimport numpy as np\nfrom tqdm import tqdm\n\n'''\nThis script takes the frames of each object, stored as one flat, consecutively\nnumbered folder per split (raw_data_train / raw_data_test), and splits them\ninto a fixed number of equal-length videos, one folder per video.\n'''\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndata_folder = '/home/cml/Downloads/visphysics_data'\nobject_lst = ['air_dancer', 'fire', 'swingstick_magnetic', 'swingstick_non_magnetic']\nnum_train_vids = [22, 1, 18, 68]\nnum_test_vids = [5, 1, 5, 17]\n\nnew_data_folder = '/home/cml/Downloads/visphysics_data_new'\n\n\nfor i in tqdm(range(len(object_lst))):\n    obj = object_lst[i]\n    old_folder = os.path.join(data_folder, obj)\n    new_folder = os.path.join(new_data_folder, obj)\n    # train\n    num_train_frames = len(os.listdir(os.path.join(old_folder, 'raw_data_train')))\n    ids = np.array(list(range(num_train_frames)))\n    # np.split requires the frame count to be an exact multiple of the number of videos\n    ids = np.split(ids, num_train_vids[i])\n    for j in tqdm(range(num_train_vids[i])):\n        new_vid_folder = os.path.join(new_folder, 'train', str(j))\n        mkdir(new_vid_folder)\n        this_ids = ids[j]\n        k = 0\n        for p_idx in this_ids:\n            src = os.path.join(old_folder, 'raw_data_train', str(p_idx) + '.jpg')\n            dest = os.path.join(new_vid_folder, str(k) + '.jpg')\n            shutil.copy(src, dest)\n            k = k + 1\n    # test\n    num_test_frames = len(os.listdir(os.path.join(old_folder, 'raw_data_test')))\n    ids = np.array(list(range(num_test_frames)))\n    ids = np.split(ids, num_test_vids[i])\n    for j in tqdm(range(num_test_vids[i])):\n        new_vid_folder = os.path.join(new_folder, 'test', str(j))\n        mkdir(new_vid_folder)\n        this_ids = ids[j]\n        k = 0\n        for p_idx in this_ids:\n            src = os.path.join(old_folder, 'raw_data_test', str(p_idx) + '.png')\n            dest = os.path.join(new_vid_folder, str(k) + '.png')\n            shutil.copy(src, dest)\n            k = k + 1\n"
  },
  {
    "path": "datainfo/double_pendulum/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        524,\n        1032,\n        960,\n        907,\n        977,\n        840,\n        877,\n        802,\n        392,\n        5,\n        746,\n        980,\n        623,\n        1068,\n        675,\n        1098,\n        931,\n        470,\n        361,\n        591,\n        867,\n        975,\n        352,\n        526,\n        414,\n        237,\n        561,\n        880,\n        942,\n        552,\n        1059,\n        789,\n        12,\n        1005,\n        464,\n        345,\n        348,\n        806,\n        631,\n        89,\n        961,\n        60,\n        1002,\n        758,\n        1046,\n        335,\n        221,\n        898,\n        177,\n        767,\n        751,\n        354,\n        848,\n        827,\n        497,\n        983,\n        70,\n        805,\n        1034,\n        1022,\n        581,\n        621,\n        388,\n        1039,\n        1072,\n        1025,\n        681,\n        247,\n        607,\n        380,\n        204,\n        852,\n        44,\n        593,\n        941,\n        448,\n        472,\n        707,\n        477,\n        1015,\n        896,\n        454,\n        59,\n        864,\n        443,\n        780,\n        18,\n        52,\n        45,\n        62,\n        650,\n        209,\n        468,\n        545,\n        912,\n        4,\n        886,\n        798,\n        58,\n        999,\n        192,\n        429,\n        777,\n        967,\n        920,\n        1014,\n        241,\n        522,\n        129,\n        275\n    ],\n    \"val\": [\n        456,\n        164,\n        736,\n        36,\n        149,\n        406,\n        1075,\n        230,\n        21,\n        1017,\n        442,\n        620,\n        214,\n        1096,\n        747,\n        259,\n        111,\n        264,\n        1089,\n        815,\n        431,\n        351,\n        395,\n        319,\n        24,\n        116,\n        485,\n        508,\n        
329,\n        719,\n        465,\n        301,\n        728,\n        663,\n        279,\n        672,\n        172,\n        540,\n        1063,\n        163,\n        171,\n        71,\n        297,\n        1011,\n        189,\n        639,\n        944,\n        112,\n        1004,\n        255,\n        287,\n        773,\n        772,\n        14,\n        463,\n        17,\n        888,\n        85,\n        72,\n        689,\n        861,\n        33,\n        261,\n        836,\n        871,\n        816,\n        564,\n        817,\n        93,\n        881,\n        185,\n        598,\n        563,\n        181,\n        1079,\n        235,\n        823,\n        28,\n        614,\n        469,\n        339,\n        627,\n        1033,\n        638,\n        553,\n        551,\n        1,\n        1040,\n        424,\n        365,\n        832,\n        496,\n        423,\n        516,\n        963,\n        1019,\n        567,\n        583,\n        373,\n        890,\n        492,\n        57,\n        972,\n        436,\n        210,\n        574,\n        796,\n        531,\n        132,\n        828\n    ],\n    \"train\": [\n        391,\n        479,\n        51,\n        1077,\n        882,\n        625,\n        268,\n        197,\n        433,\n        635,\n        246,\n        685,\n        640,\n        803,\n        131,\n        398,\n        1070,\n        580,\n        924,\n        863,\n        962,\n        238,\n        39,\n        763,\n        653,\n        455,\n        608,\n        869,\n        680,\n        768,\n        634,\n        107,\n        514,\n        509,\n        884,\n        417,\n        219,\n        146,\n        311,\n        630,\n        996,\n        290,\n        80,\n        911,\n        366,\n        353,\n        130,\n        814,\n        1001,\n        765,\n        100,\n        565,\n        487,\n        498,\n        158,\n        147,\n        862,\n        887,\n        13,\n        
529,\n        295,\n        444,\n        916,\n        795,\n        722,\n        260,\n        1016,\n        950,\n        609,\n        284,\n        895,\n        542,\n        535,\n        381,\n        450,\n        294,\n        750,\n        502,\n        494,\n        326,\n        556,\n        67,\n        1057,\n        384,\n        933,\n        143,\n        644,\n        362,\n        1084,\n        120,\n        16,\n        1094,\n        706,\n        788,\n        233,\n        905,\n        1036,\n        596,\n        733,\n        853,\n        418,\n        254,\n        119,\n        26,\n        559,\n        575,\n        549,\n        1088,\n        834,\n        966,\n        202,\n        434,\n        1095,\n        95,\n        308,\n        507,\n        590,\n        154,\n        988,\n        357,\n        1008,\n        425,\n        360,\n        234,\n        778,\n        819,\n        883,\n        90,\n        793,\n        611,\n        1028,\n        901,\n        812,\n        740,\n        432,\n        923,\n        68,\n        289,\n        167,\n        679,\n        784,\n        1093,\n        186,\n        671,\n        891,\n        849,\n        316,\n        729,\n        669,\n        213,\n        375,\n        801,\n        304,\n        991,\n        569,\n        9,\n        633,\n        603,\n        428,\n        190,\n        940,\n        642,\n        368,\n        771,\n        857,\n        282,\n        263,\n        419,\n        330,\n        1031,\n        420,\n        655,\n        876,\n        520,\n        1041,\n        544,\n        194,\n        791,\n        81,\n        122,\n        859,\n        3,\n        656,\n        1042,\n        201,\n        956,\n        753,\n        266,\n        262,\n        782,\n        243,\n        334,\n        54,\n        99,\n        341,\n        140,\n        280,\n        1027,\n        421,\n        787,\n        43,\n        401,\n      
  96,\n        115,\n        989,\n        1071,\n        973,\n        56,\n        734,\n        1061,\n        617,\n        850,\n        30,\n        928,\n        534,\n        573,\n        1067,\n        968,\n        949,\n        499,\n        343,\n        868,\n        990,\n        837,\n        242,\n        1092,\n        482,\n        1083,\n        992,\n        568,\n        350,\n        723,\n        969,\n        612,\n        467,\n        344,\n        810,\n        503,\n        1087,\n        693,\n        807,\n        619,\n        174,\n        717,\n        1076,\n        594,\n        29,\n        624,\n        586,\n        718,\n        336,\n        809,\n        926,\n        616,\n        441,\n        409,\n        1045,\n        1053,\n        954,\n        134,\n        451,\n        501,\n        537,\n        731,\n        533,\n        1010,\n        1006,\n        491,\n        708,\n        206,\n        688,\n        125,\n        927,\n        113,\n        770,\n        459,\n        709,\n        453,\n        390,\n        200,\n        713,\n        402,\n        906,\n        152,\n        538,\n        25,\n        379,\n        654,\n        766,\n        188,\n        267,\n        493,\n        661,\n        137,\n        776,\n        121,\n        764,\n        687,\n        430,\n        978,\n        570,\n        800,\n        47,\n        82,\n        701,\n        489,\n        959,\n        396,\n        762,\n        1020,\n        35,\n        216,\n        939,\n        50,\n        1007,\n        1026,\n        23,\n        948,\n        702,\n        438,\n        248,\n        987,\n        878,\n        77,\n        155,\n        278,\n        745,\n        135,\n        178,\n        998,\n        309,\n        37,\n        676,\n        759,\n        775,\n        870,\n        79,\n        1048,\n        539,\n        712,\n        699,\n        148,\n        412,\n        613,\n        986,\n  
      385,\n        34,\n        446,\n        484,\n        915,\n        207,\n        774,\n        979,\n        199,\n        657,\n        378,\n        269,\n        184,\n        371,\n        271,\n        856,\n        845,\n        394,\n        532,\n        359,\n        997,\n        673,\n        476,\n        739,\n        478,\n        236,\n        821,\n        515,\n        372,\n        665,\n        865,\n        695,\n        831,\n        965,\n        285,\n        879,\n        892,\n        958,\n        331,\n        1000,\n        666,\n        1090,\n        889,\n        374,\n        1091,\n        974,\n        276,\n        156,\n        179,\n        585,\n        643,\n        226,\n        935,\n        0,\n        714,\n        781,\n        790,\n        510,\n        1049,\n        257,\n        223,\n        447,\n        293,\n        651,\n        530,\n        825,\n        69,\n        332,\n        370,\n        1052,\n        548,\n        1065,\n        525,\n        6,\n        382,\n        291,\n        474,\n        726,\n        413,\n        55,\n        415,\n        452,\n        104,\n        196,\n        292,\n        756,\n        519,\n        1078,\n        215,\n        647,\n        716,\n        913,\n        779,\n        437,\n        1013,\n        582,\n        866,\n        748,\n        829,\n        571,\n        705,\n        299,\n        527,\n        139,\n        698,\n        277,\n        75,\n        606,\n        88,\n        1066,\n        838,\n        427,\n        124,\n        648,\n        321,\n        15,\n        725,\n        296,\n        486,\n        270,\n        97,\n        2,\n        286,\n        383,\n        193,\n        127,\n        808,\n        151,\n        338,\n        854,\n        483,\n        1080,\n        1050,\n        307,\n        677,\n        61,\n        108,\n        636,\n        700,\n        318,\n        684,\n        546,\n        166,\n   
     909,\n        84,\n        696,\n        1035,\n        462,\n        602,\n        481,\n        195,\n        760,\n        860,\n        123,\n        473,\n        435,\n        227,\n        517,\n        358,\n        900,\n        337,\n        495,\n        715,\n        315,\n        397,\n        785,\n        142,\n        936,\n        737,\n        622,\n        769,\n        660,\n        1074,\n        168,\n        393,\n        356,\n        403,\n        1018,\n        53,\n        833,\n        855,\n        659,\n        159,\n        180,\n        386,\n        641,\n        126,\n        87,\n        475,\n        169,\n        1058,\n        599,\n        844,\n        902,\n        208,\n        543,\n        918,\n        377,\n        523,\n        73,\n        595,\n        224,\n        281,\n        1024,\n        422,\n        1085,\n        490,\n        506,\n        724,\n        244,\n        231,\n        1064,\n        512,\n        953,\n        541,\n        952,\n        7,\n        994,\n        239,\n        662,\n        176,\n        410,\n        407,\n        730,\n        914,\n        572,\n        749,\n        141,\n        363,\n        63,\n        19,\n        841,\n        903,\n        1043,\n        1069,\n        830,\n        1038,\n        405,\n        11,\n        1023,\n        597,\n        449,\n        670,\n        678,\n        562,\n        440,\n        457,\n        683,\n        839,\n        253,\n        49,\n        161,\n        550,\n        946,\n        306,\n        182,\n        820,\n        566,\n        323,\n        32,\n        558,\n        145,\n        211,\n        1097,\n        300,\n        1012,\n        109,\n        312,\n        938,\n        144,\n        153,\n        183,\n        1037,\n        826,\n        743,\n        652,\n        513,\n        943,\n        157,\n        674,\n        505,\n        367,\n        921,\n        692,\n        22,\n        
76,\n        847,\n        757,\n        930,\n        752,\n        804,\n        249,\n        20,\n        872,\n        250,\n        94,\n        792,\n        649,\n        1029,\n        874,\n        103,\n        342,\n        251,\n        310,\n        592,\n        682,\n        191,\n        799,\n        42,\n        668,\n        811,\n        232,\n        985,\n        658,\n        588,\n        92,\n        458,\n        91,\n        929,\n        738,\n        369,\n        252,\n        203,\n        314,\n        710,\n        554,\n        187,\n        265,\n        364,\n        480,\n        704,\n        632,\n        220,\n        256,\n        114,\n        466,\n        615,\n        324,\n        65,\n        64,\n        408,\n        320,\n        400,\n        460,\n        327,\n        610,\n        727,\n        10,\n        27,\n        908,\n        325,\n        667,\n        212,\n        102,\n        322,\n        488,\n        894,\n        741,\n        910,\n        555,\n        744,\n        445,\n        105,\n        842,\n        1056,\n        697,\n        919,\n        1009,\n        118,\n        165,\n        899,\n        600,\n        245,\n        897,\n        754,\n        843,\n        971,\n        947,\n        932,\n        686,\n        628,\n        1021,\n        735,\n        46,\n        110,\n        283,\n        1081,\n        893,\n        786,\n        577,\n        302,\n        993,\n        273,\n        83,\n        579,\n        229,\n        951,\n        584,\n        846,\n        824,\n        601,\n        629,\n        117,\n        349,\n        851,\n        150,\n        389,\n        74,\n        416,\n        40,\n        858,\n        813,\n        945,\n        1062,\n        797,\n        500,\n        732,\n        618,\n        783,\n        298,\n        904,\n        1054,\n        346,\n        376,\n        917,\n        995,\n        957,\n        340,\n        
274,\n        794,\n        964,\n        170,\n        173,\n        136,\n        86,\n        41,\n        742,\n        66,\n        240,\n        1082,\n        970,\n        547,\n        703,\n        922,\n        560,\n        1051,\n        98,\n        818,\n        272,\n        218,\n        439,\n        347,\n        138,\n        576,\n        1047,\n        955,\n        160,\n        885,\n        288,\n        411,\n        626,\n        333,\n        934,\n        511,\n        981,\n        303,\n        399,\n        1055,\n        106,\n        504,\n        198,\n        605,\n        1073,\n        690,\n        587,\n        1044,\n        101,\n        355,\n        205,\n        387,\n        835,\n        521,\n        637,\n        720,\n        1086,\n        175,\n        471,\n        976,\n        222,\n        604,\n        38,\n        984,\n        8,\n        133,\n        258,\n        578,\n        426,\n        162,\n        761,\n        873,\n        317,\n        78,\n        937,\n        313,\n        48,\n        217,\n        128,\n        305,\n        755,\n        1030,\n        875,\n        646,\n        1003,\n        328,\n        822,\n        589,\n        691,\n        404,\n        31,\n        664,\n        536,\n        228,\n        461,\n        528,\n        711,\n        925,\n        645,\n        225,\n        1060,\n        557,\n        982,\n        694,\n        518,\n        721\n    ]\n}"
  },
  {
    "path": "datainfo/double_pendulum/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        24,\n        688,\n        255,\n        176,\n        368,\n        371,\n        58,\n        32,\n        777,\n        734,\n        919,\n        433,\n        61,\n        901,\n        966,\n        215,\n        844,\n        250,\n        272,\n        874,\n        139,\n        534,\n        772,\n        108,\n        937,\n        896,\n        698,\n        232,\n        605,\n        1066,\n        50,\n        668,\n        588,\n        60,\n        217,\n        391,\n        17,\n        699,\n        154,\n        750,\n        1001,\n        425,\n        638,\n        832,\n        1032,\n        621,\n        633,\n        982,\n        1082,\n        340,\n        664,\n        454,\n        996,\n        935,\n        718,\n        1062,\n        931,\n        1057,\n        1025,\n        1020,\n        571,\n        1003,\n        511,\n        944,\n        818,\n        330,\n        1060,\n        741,\n        724,\n        1079,\n        849,\n        912,\n        372,\n        1052,\n        736,\n        1065,\n        1044,\n        279,\n        355,\n        665,\n        361,\n        48,\n        472,\n        483,\n        363,\n        336,\n        867,\n        778,\n        652,\n        952,\n        745,\n        56,\n        1090,\n        549,\n        1028,\n        911,\n        761,\n        1042,\n        805,\n        882,\n        324,\n        73,\n        434,\n        515,\n        631,\n        346,\n        739,\n        173,\n        187,\n        115\n    ],\n    \"val\": [\n        810,\n        639,\n        233,\n        1017,\n        82,\n        256,\n        75,\n        455,\n        731,\n        238,\n        252,\n        1041,\n        725,\n        520,\n        237,\n        650,\n        464,\n        98,\n        174,\n        165,\n        134,\n        34,\n        259,\n        995,\n        686,\n        143,\n        1049,\n        716,\n   
     18,\n        1048,\n        429,\n        620,\n        268,\n        265,\n        349,\n        924,\n        147,\n        335,\n        527,\n        498,\n        402,\n        799,\n        598,\n        530,\n        986,\n        895,\n        807,\n        458,\n        989,\n        104,\n        858,\n        323,\n        703,\n        676,\n        95,\n        230,\n        484,\n        157,\n        722,\n        636,\n        883,\n        411,\n        773,\n        270,\n        923,\n        46,\n        757,\n        619,\n        784,\n        564,\n        459,\n        315,\n        31,\n        500,\n        345,\n        292,\n        1098,\n        765,\n        760,\n        642,\n        630,\n        352,\n        4,\n        37,\n        155,\n        253,\n        813,\n        44,\n        603,\n        394,\n        1,\n        708,\n        535,\n        188,\n        752,\n        160,\n        958,\n        1033,\n        130,\n        261,\n        382,\n        21,\n        940,\n        746,\n        41,\n        25,\n        69,\n        977,\n        117,\n        84\n    ],\n    \"train\": [\n        521,\n        97,\n        212,\n        436,\n        465,\n        835,\n        410,\n        195,\n        591,\n        288,\n        988,\n        828,\n        450,\n        127,\n        42,\n        93,\n        92,\n        970,\n        873,\n        239,\n        424,\n        1086,\n        956,\n        310,\n        266,\n        824,\n        142,\n        717,\n        119,\n        727,\n        1008,\n        357,\n        871,\n        584,\n        64,\n        396,\n        817,\n        406,\n        57,\n        2,\n        681,\n        9,\n        249,\n        864,\n        512,\n        198,\n        245,\n        40,\n        473,\n        1089,\n        299,\n        182,\n        601,\n        379,\n        321,\n        955,\n        763,\n        486,\n        766,\n        22,\n        
906,\n        744,\n        171,\n        316,\n        702,\n        467,\n        370,\n        898,\n        274,\n        820,\n        1015,\n        644,\n        443,\n        707,\n        13,\n        576,\n        764,\n        109,\n        838,\n        559,\n        158,\n        26,\n        670,\n        721,\n        329,\n        789,\n        798,\n        802,\n        767,\n        133,\n        431,\n        539,\n        325,\n        672,\n        485,\n        618,\n        140,\n        135,\n        438,\n        692,\n        408,\n        608,\n        997,\n        640,\n        29,\n        344,\n        572,\n        421,\n        748,\n        6,\n        327,\n        1039,\n        235,\n        308,\n        796,\n        635,\n        290,\n        581,\n        713,\n        178,\n        612,\n        943,\n        568,\n        866,\n        207,\n        546,\n        519,\n        917,\n        889,\n        214,\n        339,\n        111,\n        448,\n        35,\n        63,\n        248,\n        282,\n        405,\n        163,\n        629,\n        782,\n        226,\n        659,\n        850,\n        902,\n        487,\n        1058,\n        592,\n        540,\n        596,\n        505,\n        831,\n        801,\n        575,\n        1046,\n        800,\n        833,\n        963,\n        556,\n        758,\n        504,\n        347,\n        236,\n        206,\n        441,\n        682,\n        1085,\n        547,\n        1040,\n        507,\n        87,\n        1094,\n        580,\n        975,\n        604,\n        648,\n        569,\n        294,\n        7,\n        916,\n        607,\n        788,\n        228,\n        281,\n        573,\n        376,\n        219,\n        915,\n        241,\n        167,\n        208,\n        1080,\n        216,\n        529,\n        1075,\n        94,\n        690,\n        907,\n        542,\n        578,\n        146,\n        120,\n        968,\n        
49,\n        862,\n        211,\n        680,\n        447,\n        815,\n        586,\n        1068,\n        488,\n        693,\n        306,\n        271,\n        269,\n        314,\n        43,\n        792,\n        797,\n        976,\n        417,\n        954,\n        191,\n        499,\n        151,\n        729,\n        854,\n        848,\n        552,\n        781,\n        307,\n        407,\n        826,\n        881,\n        985,\n        1097,\n        1069,\n        397,\n        423,\n        508,\n        168,\n        76,\n        998,\n        627,\n        439,\n        860,\n        258,\n        229,\n        105,\n        645,\n        590,\n        751,\n        377,\n        759,\n        791,\n        597,\n        80,\n        106,\n        887,\n        891,\n        386,\n        753,\n        1054,\n        322,\n        100,\n        809,\n        965,\n        300,\n        843,\n        661,\n        913,\n        899,\n        617,\n        283,\n        855,\n        567,\n        99,\n        842,\n        926,\n        737,\n        574,\n        634,\n        967,\n        440,\n        987,\n        948,\n        637,\n        960,\n        886,\n        416,\n        562,\n        51,\n        513,\n        1019,\n        662,\n        732,\n        803,\n        66,\n        356,\n        460,\n        419,\n        897,\n        152,\n        932,\n        193,\n        359,\n        945,\n        606,\n        476,\n        583,\n        660,\n        624,\n        786,\n        369,\n        466,\n        506,\n        110,\n        557,\n        403,\n        1012,\n        351,\n        378,\n        1091,\n        981,\n        276,\n        415,\n        0,\n        251,\n        427,\n        953,\n        432,\n        942,\n        770,\n        683,\n        616,\n        566,\n        877,\n        787,\n        677,\n        456,\n        257,\n        70,\n        15,\n        286,\n        71,\n        
929,\n        33,\n        123,\n        689,\n        756,\n        518,\n        332,\n        827,\n        47,\n        780,\n        614,\n        790,\n        12,\n        876,\n        112,\n        632,\n        541,\n        671,\n        129,\n        845,\n        694,\n        234,\n        156,\n        201,\n        482,\n        490,\n        884,\n        348,\n        1022,\n        1021,\n        302,\n        380,\n        921,\n        738,\n        231,\n        1084,\n        197,\n        210,\n        723,\n        1016,\n        354,\n        205,\n        964,\n        1038,\n        517,\n        755,\n        593,\n        1064,\n        305,\n        918,\n        278,\n        691,\n        1005,\n        1006,\n        242,\n        675,\n        200,\n        27,\n        116,\n        492,\n        430,\n        1073,\n        1030,\n        949,\n        687,\n        59,\n        295,\n        267,\n        920,\n        613,\n        1009,\n        398,\n        227,\n        880,\n        890,\n        90,\n        1081,\n        951,\n        785,\n        865,\n        331,\n        190,\n        172,\n        23,\n        312,\n        1061,\n        122,\n        362,\n        126,\n        646,\n        175,\n        45,\n        442,\n        678,\n        563,\n        446,\n        128,\n        169,\n        857,\n        775,\n        719,\n        1070,\n        991,\n        806,\n        623,\n        754,\n        77,\n        615,\n        684,\n        145,\n        463,\n        199,\n        304,\n        714,\n        666,\n        816,\n        863,\n        277,\n        868,\n        223,\n        333,\n        841,\n        878,\n        532,\n        808,\n        469,\n        39,\n        183,\n        85,\n        179,\n        53,\n        14,\n        65,\n        861,\n        704,\n        565,\n        1010,\n        457,\n        973,\n        244,\n        225,\n        647,\n        461,\n     
   132,\n        859,\n        247,\n        202,\n        936,\n        972,\n        962,\n        343,\n        544,\n        318,\n        1007,\n        149,\n        747,\n        959,\n        611,\n        385,\n        930,\n        479,\n        360,\n        445,\n        365,\n        384,\n        1092,\n        328,\n        287,\n        350,\n        409,\n        387,\n        743,\n        555,\n        1083,\n        1011,\n        728,\n        189,\n        776,\n        194,\n        480,\n        243,\n        101,\n        162,\n        334,\n        1018,\n        643,\n        649,\n        137,\n        847,\n        298,\n        452,\n        749,\n        170,\n        994,\n        554,\n        730,\n        79,\n        1043,\n        326,\n        218,\n        373,\n        388,\n        1037,\n        1045,\n        62,\n        284,\n        971,\n        705,\n        836,\n        888,\n        983,\n        622,\n        91,\n        516,\n        153,\n        543,\n        957,\n        837,\n        892,\n        933,\n        922,\n        673,\n        319,\n        341,\n        311,\n        651,\n        264,\n        783,\n        309,\n        551,\n        742,\n        136,\n        510,\n        879,\n        946,\n        453,\n        577,\n        654,\n        550,\n        900,\n        814,\n        969,\n        641,\n        570,\n        289,\n        10,\n        1095,\n        301,\n        653,\n        390,\n        1051,\n        166,\n        55,\n        138,\n        978,\n        908,\n        1034,\n        1067,\n        81,\n        28,\n        471,\n        514,\n        273,\n        470,\n        762,\n        496,\n        829,\n        67,\n        658,\n        795,\n        3,\n        358,\n        74,\n        209,\n        478,\n        177,\n        579,\n        525,\n        667,\n        493,\n        834,\n        8,\n        141,\n        626,\n        501,\n        779,\n    
    594,\n        600,\n        669,\n        204,\n        905,\n        812,\n        856,\n        220,\n        395,\n        910,\n        582,\n        221,\n        655,\n        1047,\n        823,\n        181,\n        822,\n        1026,\n        846,\n        320,\n        11,\n        894,\n        1024,\n        412,\n        192,\n        938,\n        545,\n        934,\n        83,\n        338,\n        560,\n        875,\n        1072,\n        710,\n        925,\n        1002,\n        414,\n        280,\n        494,\n        54,\n        522,\n        999,\n        184,\n        392,\n        609,\n        30,\n        451,\n        477,\n        148,\n        587,\n        367,\n        413,\n        240,\n        1071,\n        161,\n        1096,\n        263,\n        526,\n        254,\n        528,\n        1074,\n        720,\n        872,\n        553,\n        159,\n        114,\n        293,\n        536,\n        992,\n        449,\n        337,\n        711,\n        903,\n        774,\n        399,\n        735,\n        131,\n        38,\n        663,\n        697,\n        88,\n        68,\n        426,\n        150,\n        558,\n        180,\n        364,\n        72,\n        1014,\n        589,\n        599,\n        86,\n        715,\n        695,\n        712,\n        437,\n        561,\n        1055,\n        610,\n        1078,\n        974,\n        852,\n        769,\n        303,\n        696,\n        503,\n        909,\n        1063,\n        489,\n        870,\n        495,\n        275,\n        383,\n        144,\n        468,\n        366,\n        297,\n        509,\n        1013,\n        947,\n        1050,\n        481,\n        674,\n        1035,\n        1059,\n        1087,\n        709,\n        840,\n        400,\n        656,\n        342,\n        1093,\n        939,\n        125,\n        853,\n        260,\n        124,\n        404,\n        353,\n        771,\n        904,\n        794,\n       
 733,\n        401,\n        979,\n        585,\n        118,\n        196,\n        502,\n        941,\n        224,\n        78,\n        740,\n        804,\n        811,\n        927,\n        291,\n        5,\n        885,\n        628,\n        602,\n        213,\n        474,\n        497,\n        1077,\n        420,\n        1029,\n        462,\n        16,\n        990,\n        793,\n        203,\n        313,\n        851,\n        103,\n        422,\n        839,\n        1053,\n        950,\n        381,\n        706,\n        296,\n        375,\n        625,\n        121,\n        531,\n        19,\n        374,\n        491,\n        821,\n        679,\n        96,\n        185,\n        830,\n        537,\n        428,\n        52,\n        1088,\n        595,\n        980,\n        523,\n        435,\n        444,\n        825,\n        1036,\n        1004,\n        1076,\n        701,\n        1023,\n        389,\n        657,\n        548,\n        317,\n        914,\n        475,\n        685,\n        533,\n        984,\n        222,\n        107,\n        893,\n        869,\n        186,\n        20,\n        102,\n        928,\n        246,\n        89,\n        1056,\n        524,\n        113,\n        164,\n        418,\n        393,\n        36,\n        1027,\n        961,\n        768,\n        538,\n        285,\n        1000,\n        700,\n        262,\n        993,\n        726,\n        1031,\n        819\n    ]\n}"
  },
  {
    "path": "datainfo/double_pendulum/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        887,\n        659,\n        395,\n        532,\n        890,\n        471,\n        385,\n        386,\n        882,\n        918,\n        141,\n        981,\n        368,\n        321,\n        347,\n        888,\n        1003,\n        43,\n        706,\n        159,\n        269,\n        625,\n        298,\n        417,\n        994,\n        202,\n        971,\n        32,\n        548,\n        614,\n        110,\n        78,\n        7,\n        317,\n        1065,\n        483,\n        571,\n        677,\n        773,\n        92,\n        90,\n        243,\n        850,\n        1082,\n        601,\n        41,\n        1085,\n        840,\n        136,\n        704,\n        181,\n        990,\n        987,\n        129,\n        254,\n        583,\n        546,\n        432,\n        213,\n        668,\n        334,\n        572,\n        58,\n        689,\n        475,\n        834,\n        718,\n        790,\n        1038,\n        862,\n        1076,\n        893,\n        528,\n        444,\n        1013,\n        278,\n        73,\n        199,\n        748,\n        274,\n        910,\n        808,\n        874,\n        793,\n        968,\n        551,\n        63,\n        616,\n        87,\n        326,\n        131,\n        31,\n        798,\n        1071,\n        310,\n        474,\n        308,\n        813,\n        975,\n        963,\n        392,\n        479,\n        531,\n        960,\n        26,\n        134,\n        970,\n        757,\n        267,\n        487\n    ],\n    \"val\": [\n        223,\n        897,\n        759,\n        344,\n        81,\n        173,\n        876,\n        829,\n        448,\n        229,\n        853,\n        691,\n        341,\n        329,\n        930,\n        104,\n        99,\n        714,\n        665,\n        445,\n        694,\n        191,\n        335,\n        244,\n        275,\n        823,\n        669,\n        227,\n        512,\n   
     1022,\n        1094,\n        752,\n        700,\n        582,\n        27,\n        832,\n        1056,\n        107,\n        996,\n        806,\n        307,\n        270,\n        988,\n        864,\n        936,\n        776,\n        320,\n        189,\n        372,\n        1039,\n        327,\n        615,\n        606,\n        305,\n        467,\n        643,\n        257,\n        378,\n        21,\n        692,\n        62,\n        974,\n        22,\n        501,\n        755,\n        285,\n        723,\n        623,\n        950,\n        939,\n        695,\n        361,\n        477,\n        340,\n        642,\n        648,\n        61,\n        1037,\n        647,\n        603,\n        630,\n        995,\n        20,\n        322,\n        593,\n        425,\n        807,\n        11,\n        1005,\n        561,\n        1083,\n        533,\n        264,\n        447,\n        1035,\n        958,\n        1030,\n        732,\n        737,\n        649,\n        441,\n        277,\n        519,\n        830,\n        1088,\n        635,\n        105,\n        1050,\n        697,\n        609\n    ],\n    \"train\": [\n        401,\n        148,\n        464,\n        711,\n        957,\n        1084,\n        420,\n        1007,\n        318,\n        1074,\n        579,\n        972,\n        434,\n        2,\n        770,\n        220,\n        1010,\n        540,\n        192,\n        740,\n        857,\n        788,\n        545,\n        408,\n        352,\n        163,\n        469,\n        118,\n        758,\n        534,\n        506,\n        587,\n        504,\n        396,\n        863,\n        760,\n        231,\n        794,\n        1059,\n        1049,\n        768,\n        655,\n        456,\n        279,\n        313,\n        459,\n        769,\n        371,\n        80,\n        747,\n        66,\n        67,\n        525,\n        804,\n        521,\n        144,\n        1081,\n        894,\n        1060,\n        
1066,\n        682,\n        439,\n        476,\n        142,\n        576,\n        620,\n        38,\n        672,\n        909,\n        345,\n        328,\n        1019,\n        193,\n        1087,\n        713,\n        179,\n        413,\n        156,\n        607,\n        640,\n        898,\n        1046,\n        511,\n        1027,\n        253,\n        558,\n        135,\n        781,\n        1098,\n        599,\n        945,\n        1073,\n        157,\n        905,\n        221,\n        197,\n        177,\n        59,\n        219,\n        154,\n        178,\n        486,\n        421,\n        727,\n        64,\n        53,\n        402,\n        1097,\n        355,\n        751,\n        228,\n        646,\n        631,\n        852,\n        452,\n        627,\n        1052,\n        309,\n        292,\n        462,\n        325,\n        784,\n        555,\n        83,\n        225,\n        149,\n        350,\n        1025,\n        1031,\n        468,\n        427,\n        969,\n        49,\n        846,\n        836,\n        262,\n        473,\n        1012,\n        520,\n        12,\n        573,\n        60,\n        212,\n        1048,\n        973,\n        820,\n        671,\n        374,\n        460,\n        249,\n        859,\n        171,\n        172,\n        198,\n        218,\n        407,\n        746,\n        908,\n        502,\n        112,\n        1036,\n        450,\n        594,\n        1089,\n        999,\n        931,\n        809,\n        301,\n        137,\n        236,\n        819,\n        639,\n        753,\n        843,\n        1002,\n        1069,\n        75,\n        207,\n        100,\n        214,\n        300,\n        346,\n        720,\n        8,\n        578,\n        89,\n        188,\n        272,\n        1070,\n        418,\n        84,\n        725,\n        167,\n        238,\n        875,\n        800,\n        470,\n        390,\n        1055,\n        835,\n        211,\n        
312,\n        217,\n        656,\n        0,\n        405,\n        151,\n        941,\n        1026,\n        174,\n        375,\n        586,\n        304,\n        814,\n        315,\n        265,\n        855,\n        1072,\n        929,\n        143,\n        565,\n        778,\n        485,\n        779,\n        337,\n        201,\n        562,\n        399,\n        186,\n        709,\n        500,\n        608,\n        1080,\n        362,\n        29,\n        743,\n        621,\n        42,\n        48,\n        121,\n        949,\n        1040,\n        696,\n        266,\n        557,\n        685,\n        6,\n        690,\n        624,\n        10,\n        1044,\n        955,\n        952,\n        719,\n        424,\n        16,\n        680,\n        162,\n        415,\n        287,\n        575,\n        821,\n        411,\n        363,\n        780,\n        503,\n        754,\n        71,\n        14,\n        86,\n        935,\n        702,\n        394,\n        299,\n        717,\n        721,\n        921,\n        589,\n        977,\n        693,\n        658,\n        889,\n        916,\n        496,\n        114,\n        296,\n        185,\n        938,\n        842,\n        712,\n        792,\n        772,\n        19,\n        233,\n        454,\n        224,\n        1018,\n        896,\n        789,\n        966,\n        155,\n        698,\n        106,\n        216,\n        687,\n        872,\n        242,\n        449,\n        404,\n        82,\n        535,\n        480,\n        209,\n        998,\n        482,\n        629,\n        248,\n        113,\n        1041,\n        705,\n        1016,\n        673,\n        775,\n        40,\n        17,\n        370,\n        1034,\n        68,\n        841,\n        168,\n        165,\n        507,\n        9,\n        801,\n        762,\n        39,\n        259,\n        364,\n        316,\n        633,\n        803,\n        25,\n        319,\n        666,\n        239,\n    
    708,\n        914,\n        728,\n        145,\n        1057,\n        412,\n        815,\n        472,\n        518,\n        699,\n        108,\n        619,\n        397,\n        735,\n        564,\n        652,\n        684,\n        745,\n        251,\n        119,\n        338,\n        951,\n        943,\n        379,\n        97,\n        797,\n        387,\n        559,\n        1061,\n        466,\n        458,\n        351,\n        844,\n        1032,\n        324,\n        158,\n        679,\n        282,\n        161,\n        443,\n        208,\n        828,\n        15,\n        899,\n        605,\n        526,\n        103,\n        670,\n        653,\n        1045,\n        543,\n        451,\n        517,\n        597,\n        724,\n        976,\n        303,\n        65,\n        516,\n        1000,\n        574,\n        343,\n        1047,\n        290,\n        96,\n        357,\n        674,\n        416,\n        654,\n        667,\n        288,\n        617,\n        1058,\n        1095,\n        513,\n        437,\n        77,\n        57,\n        764,\n        498,\n        457,\n        736,\n        414,\n        822,\n        993,\n        206,\n        170,\n        928,\n        263,\n        393,\n        356,\n        681,\n        831,\n        628,\n        774,\n        937,\n        816,\n        913,\n        196,\n        138,\n        956,\n        750,\n        660,\n        200,\n        923,\n        294,\n        1064,\n        283,\n        510,\n        1028,\n        366,\n        707,\n        810,\n        947,\n        650,\n        595,\n        610,\n        246,\n        234,\n        95,\n        331,\n        153,\n        284,\n        854,\n        932,\n        638,\n        1001,\n        783,\n        169,\n        481,\n        72,\n        286,\n        55,\n        1075,\n        817,\n        901,\n        911,\n        455,\n        926,\n        91,\n        76,\n        741,\n        373,\n 
       489,\n        756,\n        777,\n        380,\n        505,\n        927,\n        967,\n        596,\n        367,\n        536,\n        180,\n        560,\n        252,\n        891,\n        1051,\n        1009,\n        585,\n        289,\n        538,\n        552,\n        946,\n        1024,\n        917,\n        644,\n        730,\n        591,\n        877,\n        377,\n        686,\n        895,\n        256,\n        989,\n        115,\n        442,\n        1020,\n        281,\n        924,\n        676,\n        786,\n        400,\n        742,\n        215,\n        984,\n        497,\n        388,\n        446,\n        339,\n        342,\n        30,\n        824,\n        997,\n        463,\n        365,\n        839,\n        851,\n        920,\n        664,\n        120,\n        866,\n        93,\n        881,\n        922,\n        925,\n        641,\n        884,\n        261,\n        845,\n        900,\n        140,\n        1015,\n        194,\n        867,\n        260,\n        865,\n        749,\n        715,\n        235,\n        478,\n        811,\n        767,\n        74,\n        802,\n        688,\n        570,\n        1090,\n        210,\n        491,\n        622,\n        147,\n        1067,\n        892,\n        645,\n        731,\n        176,\n        205,\n        391,\n        490,\n        1068,\n        23,\n        314,\n        222,\n        598,\n        791,\n        302,\n        550,\n        237,\n        703,\n        744,\n        150,\n        965,\n        273,\n        139,\n        190,\n        733,\n        453,\n        612,\n        611,\n        509,\n        98,\n        166,\n        152,\n        410,\n        837,\n        495,\n        529,\n        983,\n        878,\n        964,\n        164,\n        422,\n        240,\n        907,\n        661,\n        85,\n        701,\n        569,\n        358,\n        406,\n        613,\n        376,\n        785,\n        403,\n        
959,\n        465,\n        255,\n        636,\n        247,\n        79,\n        184,\n        954,\n        332,\n        563,\n        782,\n        566,\n        102,\n        567,\n        109,\n        796,\n        541,\n        523,\n        663,\n        37,\n        1029,\n        306,\n        592,\n        291,\n        493,\n        117,\n        1093,\n        683,\n        787,\n        24,\n        183,\n        384,\n        101,\n        675,\n        94,\n        761,\n        45,\n        886,\n        241,\n        146,\n        1096,\n        5,\n        1033,\n        985,\n        544,\n        577,\n        1078,\n        111,\n        1092,\n        871,\n        795,\n        726,\n        860,\n        124,\n        268,\n        982,\n        383,\n        333,\n        28,\n        133,\n        436,\n        1004,\n        435,\n        738,\n        604,\n        276,\n        1017,\n        879,\n        849,\n        18,\n        716,\n        428,\n        430,\n        739,\n        508,\n        953,\n        381,\n        915,\n        979,\n        833,\n        44,\n        651,\n        722,\n        554,\n        880,\n        618,\n        1063,\n        382,\n        1006,\n        182,\n        580,\n        883,\n        389,\n        3,\n        568,\n        130,\n        126,\n        438,\n        336,\n        870,\n        271,\n        369,\n        311,\n        600,\n        398,\n        662,\n        991,\n        359,\n        1023,\n        869,\n        160,\n        323,\n        942,\n        514,\n        527,\n        88,\n        729,\n        827,\n        494,\n        70,\n        766,\n        127,\n        47,\n        1042,\n        56,\n        1,\n        330,\n        484,\n        52,\n        433,\n        992,\n        539,\n        632,\n        903,\n        1043,\n        848,\n        51,\n        409,\n        584,\n        934,\n        499,\n        1021,\n        1079,\n        
280,\n        245,\n        799,\n        1086,\n        547,\n        125,\n        1014,\n        226,\n        360,\n        948,\n        1054,\n        553,\n        258,\n        128,\n        818,\n        116,\n        626,\n        54,\n        838,\n        549,\n        537,\n        961,\n        904,\n        902,\n        919,\n        515,\n        175,\n        873,\n        861,\n        492,\n        13,\n        50,\n        590,\n        440,\n        203,\n        524,\n        710,\n        35,\n        46,\n        250,\n        122,\n        293,\n        602,\n        69,\n        232,\n        349,\n        556,\n        295,\n        1091,\n        765,\n        678,\n        771,\n        933,\n        763,\n        33,\n        906,\n        734,\n        986,\n        1062,\n        522,\n        637,\n        488,\n        4,\n        204,\n        1008,\n        423,\n        36,\n        419,\n        581,\n        429,\n        297,\n        426,\n        978,\n        354,\n        1053,\n        885,\n        812,\n        530,\n        1011,\n        431,\n        132,\n        940,\n        353,\n        634,\n        825,\n        1077,\n        657,\n        847,\n        868,\n        348,\n        944,\n        187,\n        588,\n        858,\n        856,\n        826,\n        962,\n        195,\n        542,\n        34,\n        123,\n        805,\n        230,\n        980,\n        461,\n        912\n    ]\n}"
  },
  {
    "path": "datainfo/elastic_pendulum/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        187,\n        370,\n        362,\n        470,\n        57,\n        938,\n        678,\n        3,\n        708,\n        1137,\n        730,\n        993,\n        846,\n        1033,\n        409,\n        746,\n        985,\n        114,\n        872,\n        420,\n        1062,\n        264,\n        1049,\n        785,\n        11,\n        551,\n        940,\n        723,\n        704,\n        1052,\n        828,\n        475,\n        1105,\n        408,\n        25,\n        464,\n        1028,\n        345,\n        348,\n        806,\n        631,\n        89,\n        961,\n        60,\n        1002,\n        758,\n        1142,\n        1066,\n        335,\n        221,\n        1041,\n        898,\n        177,\n        767,\n        1123,\n        751,\n        354,\n        848,\n        827,\n        497,\n        983,\n        70,\n        805,\n        1034,\n        1022,\n        581,\n        621,\n        388,\n        1039,\n        1171,\n        1025,\n        681,\n        247,\n        607,\n        380,\n        204,\n        1139,\n        852,\n        44,\n        593,\n        941,\n        448,\n        472,\n        707,\n        477,\n        1132,\n        1015,\n        896,\n        454,\n        1080,\n        59,\n        864,\n        443,\n        780,\n        18,\n        1108,\n        52,\n        45,\n        62,\n        650,\n        209,\n        468,\n        545,\n        912,\n        4,\n        886,\n        798,\n        58,\n        999,\n        192,\n        429,\n        777,\n        967,\n        920,\n        1014,\n        241,\n        522,\n        129,\n        1165,\n        275\n    ],\n    \"val\": [\n        411,\n        892,\n        626,\n        333,\n        17,\n        1071,\n        516,\n        303,\n        399,\n        1151,\n        960,\n        106,\n        504,\n        198,\n        605,\n        1172,\n        918,\n        
690,\n        587,\n        210,\n        101,\n        355,\n        205,\n        387,\n        1000,\n        521,\n        637,\n        720,\n        1186,\n        890,\n        888,\n        847,\n        175,\n        471,\n        583,\n        922,\n        1096,\n        1046,\n        839,\n        604,\n        38,\n        870,\n        899,\n        574,\n        8,\n        133,\n        258,\n        578,\n        426,\n        162,\n        761,\n        305,\n        1122,\n        939,\n        317,\n        78,\n        879,\n        1036,\n        313,\n        1053,\n        1167,\n        217,\n        991,\n        257,\n        611,\n        120,\n        1093,\n        657,\n        808,\n        1178,\n        457,\n        923,\n        451,\n        873,\n        1183,\n        328,\n        72,\n        299,\n        813,\n        36,\n        461,\n        42,\n        884,\n        428,\n        1044,\n        519,\n        222,\n        529,\n        385,\n        862,\n        703,\n        791,\n        638,\n        48,\n        233,\n        970,\n        1016,\n        659,\n        931,\n        603,\n        558,\n        344,\n        1079,\n        326,\n        342,\n        142,\n        594,\n        705,\n        378,\n        224,\n        550,\n        511,\n        575,\n        29,\n        927,\n        34,\n        170,\n        144,\n        66,\n        1196\n    ],\n    \"train\": [\n        483,\n        489,\n        341,\n        685,\n        96,\n        1,\n        1148,\n        542,\n        656,\n        744,\n        1085,\n        613,\n        855,\n        775,\n        680,\n        111,\n        1081,\n        648,\n        308,\n        733,\n        1023,\n        189,\n        759,\n        61,\n        482,\n        639,\n        1190,\n        228,\n        776,\n        596,\n        885,\n        1024,\n        1187,\n        1048,\n        295,\n        1158,\n        1146,\n        
1094,\n        819,\n        684,\n        540,\n        793,\n        856,\n        1099,\n        622,\n        262,\n        773,\n        609,\n        107,\n        263,\n        260,\n        1160,\n        958,\n        474,\n        672,\n        739,\n        442,\n        478,\n        887,\n        1083,\n        734,\n        612,\n        1084,\n        956,\n        843,\n        55,\n        1067,\n        750,\n        823,\n        9,\n        532,\n        112,\n        282,\n        329,\n        520,\n        874,\n        359,\n        1125,\n        1102,\n        514,\n        668,\n        1164,\n        459,\n        197,\n        844,\n        753,\n        635,\n        633,\n        185,\n        1195,\n        966,\n        201,\n        906,\n        1059,\n        569,\n        146,\n        580,\n        39,\n        1047,\n        446,\n        597,\n        276,\n        43,\n        1042,\n        533,\n        561,\n        166,\n        1003,\n        154,\n        778,\n        647,\n        243,\n        67,\n        768,\n        617,\n        1149,\n        634,\n        981,\n        649,\n        450,\n        507,\n        215,\n        1057,\n        740,\n        691,\n        710,\n        1120,\n        858,\n        889,\n        682,\n        71,\n        531,\n        869,\n        415,\n        849,\n        444,\n        237,\n        716,\n        214,\n        764,\n        735,\n        765,\n        595,\n        486,\n        90,\n        834,\n        989,\n        437,\n        292,\n        1127,\n        549,\n        330,\n        16,\n        304,\n        559,\n        1168,\n        104,\n        186,\n        916,\n        719,\n        99,\n        623,\n        476,\n        293,\n        498,\n        286,\n        1136,\n        905,\n        503,\n        143,\n        952,\n        917,\n        167,\n        242,\n        26,\n        797,\n        149,\n        315,\n        652,\n        
897,\n        871,\n        234,\n        271,\n        223,\n        1029,\n        769,\n        238,\n        673,\n        924,\n        294,\n        1199,\n        567,\n        33,\n        207,\n        821,\n        945,\n        795,\n        895,\n        548,\n        784,\n        729,\n        987,\n        412,\n        908,\n        986,\n        965,\n        375,\n        427,\n        6,\n        891,\n        456,\n        195,\n        980,\n        641,\n        435,\n        419,\n        1004,\n        925,\n        721,\n        538,\n        80,\n        718,\n        762,\n        679,\n        1063,\n        131,\n        937,\n        311,\n        307,\n        552,\n        485,\n        1054,\n        1043,\n        959,\n        1176,\n        840,\n        1037,\n        432,\n        699,\n        1177,\n        812,\n        424,\n        1157,\n        752,\n        360,\n        190,\n        68,\n        316,\n        693,\n        1194,\n        383,\n        384,\n        951,\n        517,\n        741,\n        1110,\n        1072,\n        671,\n        676,\n        219,\n        179,\n        1121,\n        30,\n        396,\n        1106,\n        1005,\n        585,\n        402,\n        321,\n        954,\n        943,\n        606,\n        749,\n        255,\n        859,\n        530,\n        865,\n        390,\n        395,\n        976,\n        877,\n        556,\n        268,\n        1007,\n        755,\n        957,\n        909,\n        702,\n        933,\n        509,\n        929,\n        130,\n        582,\n        289,\n        1073,\n        277,\n        756,\n        973,\n        1100,\n        394,\n        525,\n        81,\n        598,\n        100,\n        147,\n        379,\n        196,\n        644,\n        851,\n        2,\n        782,\n        351,\n        824,\n        1090,\n        56,\n        995,\n        586,\n        636,\n        919,\n        660,\n        108,\n        
199,\n        1045,\n        280,\n        1091,\n        1107,\n        553,\n        757,\n        421,\n        193,\n        655,\n        653,\n        119,\n        494,\n        492,\n        1019,\n        675,\n        343,\n        1134,\n        140,\n        339,\n        692,\n        453,\n        1018,\n        438,\n        1163,\n        964,\n        571,\n        695,\n        640,\n        28,\n        696,\n        95,\n        51,\n        510,\n        164,\n        493,\n        910,\n        54,\n        1188,\n        156,\n        911,\n        998,\n        1170,\n        123,\n        838,\n        158,\n        184,\n        760,\n        401,\n        371,\n        1141,\n        974,\n        269,\n        235,\n        134,\n        770,\n        75,\n        372,\n        747,\n        227,\n        296,\n        1032,\n        374,\n        1173,\n        487,\n        97,\n        434,\n        127,\n        893,\n        1129,\n        841,\n        206,\n        591,\n        15,\n        338,\n        125,\n        361,\n        619,\n        694,\n        113,\n        366,\n        599,\n        413,\n        469,\n        404,\n        842,\n        88,\n        1155,\n        787,\n        200,\n        1092,\n        1068,\n        738,\n        714,\n        463,\n        630,\n        543,\n        913,\n        334,\n        152,\n        481,\n        1114,\n        645,\n        285,\n        462,\n        854,\n        188,\n        337,\n        267,\n        425,\n        357,\n        984,\n        717,\n        137,\n        1031,\n        121,\n        947,\n        270,\n        1131,\n        473,\n        484,\n        1086,\n        1013,\n        836,\n        570,\n        47,\n        353,\n        82,\n        1088,\n        832,\n        84,\n        501,\n        226,\n        336,\n        1075,\n        666,\n        35,\n        216,\n        781,\n        508,\n        50,\n        465,\n        
491,\n        23,\n        763,\n        669,\n        926,\n        674,\n        132,\n        687,\n        236,\n        77,\n        155,\n        278,\n        350,\n        1058,\n        715,\n        1082,\n        534,\n        135,\n        467,\n        178,\n        5,\n        309,\n        397,\n        417,\n        37,\n        754,\n        670,\n        433,\n        902,\n        815,\n        524,\n        479,\n        544,\n        664,\n        79,\n        701,\n        789,\n        528,\n        748,\n        829,\n        148,\n        332,\n        495,\n        904,\n        539,\n        381,\n        745,\n        790,\n        69,\n        1061,\n        414,\n        624,\n        398,\n        727,\n        368,\n        202,\n        537,\n        174,\n        535,\n        590,\n        1162,\n        1051,\n        643,\n        978,\n        1087,\n        642,\n        602,\n        915,\n        115,\n        972,\n        779,\n        1038,\n        731,\n        358,\n        816,\n        452,\n        850,\n        0,\n        331,\n        1145,\n        833,\n        515,\n        447,\n        568,\n        382,\n        663,\n        1135,\n        139,\n        1192,\n        13,\n        589,\n        689,\n        867,\n        290,\n        1180,\n        392,\n        213,\n        124,\n        430,\n        502,\n        365,\n        722,\n        172,\n        935,\n        651,\n        318,\n        279,\n        955,\n        950,\n        151,\n        291,\n        736,\n        724,\n        266,\n        248,\n        31,\n        21,\n        883,\n        665,\n        194,\n        1184,\n        572,\n        837,\n        254,\n        284,\n        820,\n        1150,\n        418,\n        546,\n        441,\n        122,\n        1021,\n        499,\n        725,\n        662,\n        963,\n        809,\n        406,\n        712,\n        391,\n        246,\n        455,\n        1124,\n   
     1065,\n        85,\n        875,\n        928,\n        914,\n        783,\n        536,\n        1074,\n        766,\n        168,\n        527,\n        393,\n        356,\n        403,\n        1027,\n        53,\n        994,\n        997,\n        436,\n        700,\n        159,\n        180,\n        386,\n        860,\n        831,\n        620,\n        126,\n        87,\n        608,\n        1111,\n        169,\n        1154,\n        565,\n        713,\n        1011,\n        319,\n        208,\n        646,\n        163,\n        377,\n        523,\n        73,\n        709,\n        1069,\n        281,\n        625,\n        573,\n        1117,\n        422,\n        230,\n        490,\n        506,\n        814,\n        244,\n        231,\n        654,\n        1161,\n        512,\n        772,\n        541,\n        181,\n        7,\n        944,\n        239,\n        932,\n        627,\n        176,\n        410,\n        1098,\n        407,\n        868,\n        802,\n        737,\n        907,\n        1153,\n        141,\n        901,\n        363,\n        63,\n        19,\n        880,\n        1008,\n        661,\n        24,\n        1156,\n        1101,\n        992,\n        1143,\n        405,\n        1104,\n        1115,\n        711,\n        449,\n        794,\n        1119,\n        562,\n        440,\n        1030,\n        698,\n        811,\n        1006,\n        253,\n        683,\n        49,\n        161,\n        1070,\n        975,\n        306,\n        182,\n        979,\n        706,\n        566,\n        688,\n        1109,\n        323,\n        32,\n        1060,\n        145,\n        211,\n        1197,\n        300,\n        616,\n        526,\n        726,\n        109,\n        312,\n        817,\n        1077,\n        153,\n        183,\n        1198,\n        969,\n        996,\n        881,\n        774,\n        513,\n        1133,\n        157,\n        799,\n        505,\n        367,\n        
297,\n        822,\n        1179,\n        22,\n        76,\n        1095,\n        1017,\n        900,\n        1193,\n        894,\n        1055,\n        249,\n        20,\n        225,\n        250,\n        94,\n        818,\n        1103,\n        962,\n        557,\n        103,\n        1064,\n        251,\n        310,\n        592,\n        810,\n        191,\n        953,\n        1130,\n        792,\n        968,\n        232,\n        948,\n        658,\n        588,\n        826,\n        92,\n        458,\n        771,\n        91,\n        287,\n        876,\n        369,\n        252,\n        203,\n        314,\n        1138,\n        554,\n        1169,\n        265,\n        364,\n        677,\n        480,\n        1175,\n        835,\n        796,\n        632,\n        803,\n        220,\n        256,\n        1097,\n        466,\n        615,\n        324,\n        65,\n        64,\n        1113,\n        320,\n        400,\n        460,\n        327,\n        610,\n        743,\n        863,\n        866,\n        10,\n        27,\n        853,\n        325,\n        667,\n        212,\n        102,\n        322,\n        488,\n        728,\n        259,\n        1056,\n        301,\n        555,\n        825,\n        882,\n        445,\n        105,\n        1010,\n        1009,\n        1152,\n        697,\n        171,\n        1040,\n        118,\n        165,\n        431,\n        600,\n        804,\n        245,\n        1189,\n        1020,\n        1116,\n        845,\n        1001,\n        423,\n        93,\n        14,\n        686,\n        628,\n        12,\n        1026,\n        1144,\n        46,\n        1126,\n        110,\n        283,\n        1181,\n        1012,\n        934,\n        577,\n        302,\n        373,\n        273,\n        83,\n        579,\n        229,\n        563,\n        584,\n        1166,\n        1140,\n        800,\n        601,\n        629,\n        117,\n        349,\n        128,\n      
  1089,\n        150,\n        807,\n        1185,\n        389,\n        74,\n        416,\n        40,\n        1035,\n        971,\n        788,\n        564,\n        1159,\n        949,\n        500,\n        732,\n        988,\n        618,\n        990,\n        930,\n        298,\n        116,\n        1118,\n        346,\n        376,\n        261,\n        861,\n        518,\n        614,\n        340,\n        1191,\n        274,\n        946,\n        1112,\n        1076,\n        173,\n        136,\n        86,\n        41,\n        742,\n        1078,\n        240,\n        1182,\n        786,\n        496,\n        547,\n        1050,\n        942,\n        903,\n        936,\n        352,\n        560,\n        1147,\n        857,\n        98,\n        977,\n        272,\n        218,\n        439,\n        347,\n        138,\n        801,\n        576,\n        830,\n        1128,\n        878,\n        982,\n        160,\n        1174,\n        288,\n        921\n    ]\n}"
  },
  {
    "path": "datainfo/elastic_pendulum/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        789,\n        3,\n        1071,\n        376,\n        321,\n        261,\n        523,\n        764,\n        43,\n        83,\n        51,\n        138,\n        235,\n        169,\n        1167,\n        510,\n        352,\n        737,\n        742,\n        116,\n        65,\n        866,\n        123,\n        431,\n        501,\n        544,\n        1163,\n        1069,\n        1113,\n        464,\n        559,\n        100,\n        120,\n        217,\n        391,\n        17,\n        699,\n        154,\n        750,\n        1048,\n        1001,\n        425,\n        638,\n        832,\n        1039,\n        1060,\n        1032,\n        621,\n        633,\n        982,\n        1181,\n        340,\n        664,\n        454,\n        996,\n        935,\n        718,\n        1171,\n        931,\n        1152,\n        1055,\n        1025,\n        1020,\n        571,\n        1003,\n        511,\n        1086,\n        944,\n        818,\n        330,\n        1156,\n        741,\n        724,\n        1178,\n        1075,\n        849,\n        912,\n        372,\n        1146,\n        1052,\n        736,\n        1162,\n        1044,\n        279,\n        355,\n        665,\n        361,\n        48,\n        472,\n        483,\n        363,\n        1147,\n        336,\n        1076,\n        867,\n        778,\n        652,\n        952,\n        745,\n        56,\n        1191,\n        549,\n        1028,\n        911,\n        1114,\n        761,\n        1042,\n        805,\n        882,\n        324,\n        1190,\n        73,\n        434,\n        515,\n        631,\n        346,\n        739,\n        173,\n        187,\n        115\n    ],\n    \"val\": [\n        954,\n        706,\n        296,\n        375,\n        625,\n        121,\n        943,\n        531,\n        19,\n        374,\n        491,\n        821,\n        679,\n        96,\n        942,\n        185,\n        983,\n   
     537,\n        951,\n        428,\n        902,\n        52,\n        605,\n        595,\n        21,\n        1158,\n        435,\n        444,\n        825,\n        746,\n        215,\n        986,\n        1175,\n        1037,\n        701,\n        1108,\n        389,\n        657,\n        548,\n        317,\n        1109,\n        475,\n        685,\n        533,\n        25,\n        222,\n        107,\n        237,\n        768,\n        186,\n        20,\n        102,\n        104,\n        246,\n        89,\n        1151,\n        524,\n        113,\n        1029,\n        418,\n        393,\n        1046,\n        1117,\n        9,\n        570,\n        1101,\n        525,\n        1160,\n        467,\n        164,\n        513,\n        150,\n        910,\n        476,\n        505,\n        64,\n        474,\n        928,\n        196,\n        349,\n        1149,\n        269,\n        68,\n        518,\n        1099,\n        287,\n        36,\n        859,\n        536,\n        530,\n        698,\n        294,\n        671,\n        997,\n        804,\n        1085,\n        917,\n        49,\n        208,\n        647,\n        191,\n        461,\n        968,\n        314,\n        822,\n        540,\n        93,\n        918,\n        1194,\n        63,\n        1000,\n        690,\n        585,\n        231,\n        704,\n        8,\n        74,\n        310,\n        507,\n        88\n    ],\n    \"train\": [\n        689,\n        907,\n        403,\n        285,\n        1073,\n        796,\n        316,\n        140,\n        1183,\n        132,\n        861,\n        1047,\n        167,\n        556,\n        985,\n        999,\n        1185,\n        891,\n        984,\n        199,\n        717,\n        40,\n        937,\n        354,\n        405,\n        271,\n        1155,\n        1125,\n        262,\n        900,\n        292,\n        304,\n        151,\n        909,\n        397,\n        580,\n        820,\n        
646,\n        667,\n        1057,\n        183,\n        684,\n        760,\n        450,\n        1159,\n        726,\n        198,\n        747,\n        880,\n        1012,\n        335,\n        719,\n        39,\n        913,\n        84,\n        168,\n        207,\n        1051,\n        396,\n        618,\n        1088,\n        834,\n        567,\n        146,\n        923,\n        814,\n        921,\n        629,\n        299,\n        1139,\n        45,\n        734,\n        216,\n        758,\n        302,\n        424,\n        128,\n        60,\n        406,\n        85,\n        447,\n        205,\n        710,\n        547,\n        566,\n        32,\n        156,\n        18,\n        892,\n        105,\n        1170,\n        239,\n        659,\n        1061,\n        797,\n        76,\n        248,\n        541,\n        641,\n        783,\n        969,\n        1199,\n        508,\n        94,\n        58,\n        394,\n        639,\n        863,\n        479,\n        819,\n        1034,\n        487,\n        947,\n        862,\n        283,\n        305,\n        327,\n        438,\n        744,\n        975,\n        386,\n        723,\n        202,\n        756,\n        979,\n        236,\n        1092,\n        1094,\n        774,\n        343,\n        219,\n        920,\n        360,\n        945,\n        601,\n        810,\n        998,\n        568,\n        932,\n        197,\n        603,\n        830,\n        527,\n        808,\n        940,\n        46,\n        792,\n        875,\n        816,\n        1005,\n        506,\n        916,\n        634,\n        976,\n        244,\n        1009,\n        2,\n        1198,\n        290,\n        766,\n        1182,\n        840,\n        417,\n        182,\n        249,\n        210,\n        443,\n        623,\n        307,\n        590,\n        1043,\n        592,\n        925,\n        155,\n        858,\n        675,\n        90,\n        427,\n        442,\n        848,\n   
     846,\n        1093,\n        692,\n        1030,\n        1131,\n        35,\n        255,\n        111,\n        1166,\n        469,\n        1016,\n        22,\n        614,\n        377,\n        1018,\n        26,\n        195,\n        499,\n        165,\n        37,\n        333,\n        883,\n        899,\n        247,\n        906,\n        981,\n        288,\n        365,\n        616,\n        693,\n        1070,\n        7,\n        1135,\n        441,\n        1179,\n        874,\n        188,\n        992,\n        851,\n        347,\n        1054,\n        456,\n        258,\n        991,\n        529,\n        1110,\n        1066,\n        970,\n        69,\n        370,\n        457,\n        735,\n        1137,\n        546,\n        1050,\n        13,\n        879,\n        312,\n        348,\n        872,\n        23,\n        226,\n        611,\n        1130,\n        678,\n        445,\n        57,\n        557,\n        569,\n        853,\n        668,\n        532,\n        272,\n        806,\n        379,\n        977,\n        974,\n        223,\n        266,\n        385,\n        127,\n        971,\n        31,\n        562,\n        1112,\n        1058,\n        624,\n        753,\n        798,\n        1161,\n        300,\n        281,\n        980,\n        552,\n        253,\n        42,\n        142,\n        228,\n        486,\n        573,\n        436,\n        1103,\n        407,\n        459,\n        781,\n        1023,\n        1142,\n        780,\n        1077,\n        225,\n        175,\n        517,\n        714,\n        163,\n        29,\n        1157,\n        622,\n        14,\n        6,\n        961,\n        473,\n        274,\n        660,\n        157,\n        329,\n        295,\n        1017,\n        241,\n        617,\n        1065,\n        1021,\n        158,\n        731,\n        828,\n        811,\n        1064,\n        227,\n        378,\n        888,\n        565,\n        1063,\n        586,\n   
     1116,\n        411,\n        1098,\n        707,\n        232,\n        229,\n        87,\n        1100,\n        423,\n        833,\n        1041,\n        838,\n        1134,\n        612,\n        133,\n        896,\n        214,\n        109,\n        578,\n        946,\n        887,\n        645,\n        398,\n        1121,\n        539,\n        757,\n        670,\n        212,\n        211,\n        687,\n        572,\n        865,\n        1196,\n        700,\n        80,\n        584,\n        149,\n        588,\n        662,\n        351,\n        1136,\n        972,\n        542,\n        708,\n        325,\n        521,\n        455,\n        1111,\n        1096,\n        957,\n        432,\n        306,\n        786,\n        655,\n        1033,\n        458,\n        591,\n        850,\n        1193,\n        134,\n        380,\n        1174,\n        1148,\n        1144,\n        1040,\n        1010,\n        962,\n        200,\n        99,\n        683,\n        876,\n        384,\n        282,\n        308,\n        482,\n        1192,\n        1123,\n        857,\n        357,\n        894,\n        583,\n        669,\n        117,\n        615,\n        250,\n        785,\n        466,\n        564,\n        139,\n        575,\n        1090,\n        607,\n        50,\n        415,\n        1053,\n        1097,\n        868,\n        950,\n        66,\n        498,\n        1180,\n        1118,\n        174,\n        152,\n        953,\n        193,\n        967,\n        898,\n        1084,\n        77,\n        1019,\n        847,\n        53,\n        705,\n        776,\n        345,\n        649,\n        1056,\n        268,\n        581,\n        201,\n        490,\n        331,\n        318,\n        635,\n        619,\n        672,\n        519,\n        598,\n        27,\n        632,\n        108,\n        492,\n        534,\n        126,\n        59,\n        276,\n        812,\n        0,\n        251,\n        855,\n        
484,\n        122,\n        448,\n        270,\n        1124,\n        596,\n        267,\n        416,\n        941,\n        835,\n        949,\n        964,\n        749,\n        930,\n        800,\n        277,\n        864,\n        356,\n        70,\n        15,\n        286,\n        71,\n        1013,\n        33,\n        1102,\n        1127,\n        460,\n        439,\n        897,\n        613,\n        332,\n        419,\n        172,\n        784,\n        504,\n        703,\n        933,\n        12,\n        844,\n        112,\n        606,\n        772,\n        948,\n        129,\n        4,\n        593,\n        408,\n        234,\n        993,\n        666,\n        637,\n        179,\n        512,\n        965,\n        839,\n        966,\n        110,\n        676,\n        463,\n        751,\n        421,\n        815,\n        465,\n        410,\n        1078,\n        794,\n        92,\n        1189,\n        1015,\n        485,\n        1067,\n        955,\n        233,\n        597,\n        1195,\n        119,\n        1059,\n        576,\n        135,\n        278,\n        1133,\n        440,\n        1035,\n        759,\n        147,\n        1107,\n        929,\n        344,\n        47,\n        245,\n        252,\n        914,\n        608,\n        654,\n        256,\n        369,\n        430,\n        885,\n        257,\n        339,\n        769,\n        145,\n        97,\n        680,\n        206,\n        323,\n        807,\n        1022,\n        752,\n        1138,\n        554,\n        787,\n        242,\n        604,\n        555,\n        738,\n        446,\n        721,\n        809,\n        171,\n        359,\n        106,\n        711,\n        130,\n        642,\n        322,\n        190,\n        1014,\n        488,\n        802,\n        362,\n        673,\n        178,\n        677,\n        904,\n        653,\n        1083,\n        788,\n        1105,\n        630,\n        799,\n        328,\n        
775,\n        713,\n        1045,\n        350,\n        409,\n        387,\n        563,\n        1188,\n        651,\n        702,\n        915,\n        1122,\n        189,\n        919,\n        194,\n        480,\n        243,\n        101,\n        162,\n        334,\n        1106,\n        755,\n        905,\n        137,\n        813,\n        298,\n        452,\n        889,\n        170,\n        371,\n        767,\n        716,\n        79,\n        852,\n        326,\n        218,\n        373,\n        388,\n        644,\n        827,\n        62,\n        284,\n        535,\n        836,\n        574,\n        886,\n        238,\n        41,\n        681,\n        91,\n        643,\n        516,\n        153,\n        543,\n        1141,\n        901,\n        990,\n        648,\n        520,\n        95,\n        1049,\n        893,\n        658,\n        319,\n        661,\n        341,\n        311,\n        765,\n        264,\n        926,\n        309,\n        551,\n        881,\n        136,\n        1132,\n        1095,\n        958,\n        762,\n        453,\n        577,\n        1008,\n        550,\n        34,\n        963,\n        1,\n        729,\n        1024,\n        289,\n        10,\n        826,\n        301,\n        1074,\n        1087,\n        390,\n        1145,\n        166,\n        55,\n        1091,\n        640,\n        691,\n        1140,\n        790,\n        779,\n        1164,\n        934,\n        81,\n        28,\n        471,\n        514,\n        273,\n        470,\n        903,\n        1082,\n        496,\n        1129,\n        67,\n        777,\n        1002,\n        824,\n        939,\n        1081,\n        358,\n        1173,\n        209,\n        478,\n        177,\n        579,\n        1026,\n        1080,\n        493,\n        987,\n        694,\n        627,\n        1154,\n        141,\n        626,\n        1104,\n        922,\n        594,\n        600,\n        791,\n        204,\n        
1143,\n        873,\n        1011,\n        220,\n        395,\n        620,\n        582,\n        1187,\n        221,\n        770,\n        1115,\n        973,\n        181,\n        1165,\n        803,\n        728,\n        842,\n        1120,\n        320,\n        11,\n        650,\n        1184,\n        682,\n        412,\n        192,\n        636,\n        545,\n        230,\n        743,\n        1089,\n        956,\n        763,\n        730,\n        338,\n        560,\n        368,\n        1169,\n        841,\n        890,\n        960,\n        748,\n        727,\n        259,\n        414,\n        280,\n        494,\n        54,\n        522,\n        402,\n        184,\n        754,\n        392,\n        609,\n        30,\n        878,\n        451,\n        477,\n        148,\n        989,\n        587,\n        782,\n        801,\n        367,\n        413,\n        240,\n        720,\n        1168,\n        161,\n        1197,\n        263,\n        526,\n        732,\n        254,\n        528,\n        1172,\n        854,\n        433,\n        924,\n        553,\n        159,\n        114,\n        293,\n        1119,\n        176,\n        449,\n        337,\n        843,\n        686,\n        725,\n        869,\n        399,\n        871,\n        131,\n        38,\n        663,\n        829,\n        697,\n        1079,\n        884,\n        1186,\n        426,\n        988,\n        1031,\n        558,\n        180,\n        364,\n        72,\n        98,\n        589,\n        837,\n        877,\n        599,\n        86,\n        715,\n        695,\n        712,\n        437,\n        561,\n        265,\n        610,\n        500,\n        160,\n        1007,\n        1150,\n        303,\n        696,\n        856,\n        503,\n        429,\n        823,\n        1153,\n        1027,\n        489,\n        538,\n        495,\n        275,\n        383,\n        817,\n        144,\n        468,\n        366,\n        297,\n       
 509,\n        722,\n        927,\n        44,\n        795,\n        481,\n        674,\n        1128,\n        1004,\n        1177,\n        709,\n        995,\n        400,\n        656,\n        342,\n        870,\n        1068,\n        82,\n        125,\n        936,\n        260,\n        124,\n        908,\n        404,\n        353,\n        771,\n        143,\n        938,\n        733,\n        401,\n        382,\n        1072,\n        118,\n        1038,\n        502,\n        773,\n        224,\n        78,\n        740,\n        978,\n        959,\n        24,\n        291,\n        5,\n        75,\n        628,\n        602,\n        1126,\n        213,\n        1036,\n        497,\n        1176,\n        420,\n        61,\n        462,\n        831,\n        16,\n        845,\n        688,\n        793,\n        860,\n        203,\n        313,\n        1006,\n        103,\n        422,\n        994,\n        895,\n        1062,\n        315,\n        381\n    ]\n}"
  },
  {
    "path": "datainfo/elastic_pendulum/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        1194,\n        644,\n        1184,\n        23,\n        694,\n        620,\n        1067,\n        1157,\n        895,\n        1155,\n        486,\n        883,\n        555,\n        1153,\n        210,\n        1152,\n        1065,\n        942,\n        771,\n        1121,\n        283,\n        737,\n        642,\n        695,\n        86,\n        319,\n        539,\n        597,\n        835,\n        404,\n        64,\n        1096,\n        221,\n        157,\n        14,\n        634,\n        1161,\n        483,\n        1035,\n        571,\n        677,\n        773,\n        92,\n        90,\n        243,\n        850,\n        1167,\n        601,\n        41,\n        1181,\n        840,\n        136,\n        704,\n        181,\n        990,\n        987,\n        129,\n        254,\n        583,\n        546,\n        432,\n        213,\n        1109,\n        668,\n        334,\n        572,\n        58,\n        689,\n        475,\n        834,\n        1093,\n        718,\n        790,\n        1038,\n        862,\n        1172,\n        893,\n        528,\n        444,\n        1013,\n        278,\n        73,\n        199,\n        748,\n        274,\n        910,\n        808,\n        874,\n        793,\n        968,\n        551,\n        63,\n        616,\n        87,\n        326,\n        131,\n        31,\n        798,\n        1071,\n        310,\n        474,\n        308,\n        813,\n        975,\n        1125,\n        1107,\n        963,\n        392,\n        479,\n        1128,\n        531,\n        960,\n        26,\n        134,\n        1189,\n        970,\n        757,\n        267,\n        1114,\n        487\n    ],\n    \"val\": [\n        710,\n        822,\n        925,\n        35,\n        46,\n        250,\n        829,\n        122,\n        293,\n        602,\n        878,\n        1029,\n        882,\n        232,\n        911,\n        349,\n        556,\n        
295,\n        1190,\n        765,\n        678,\n        1079,\n        856,\n        467,\n        763,\n        33,\n        227,\n        734,\n        949,\n        972,\n        1145,\n        1158,\n        522,\n        637,\n        901,\n        852,\n        965,\n        1047,\n        4,\n        204,\n        159,\n        423,\n        1097,\n        36,\n        419,\n        581,\n        429,\n        297,\n        426,\n        649,\n        354,\n        1148,\n        277,\n        870,\n        812,\n        530,\n        298,\n        431,\n        132,\n        603,\n        353,\n        1051,\n        825,\n        175,\n        696,\n        570,\n        375,\n        645,\n        390,\n        69,\n        247,\n        460,\n        554,\n        923,\n        446,\n        1147,\n        163,\n        346,\n        897,\n        459,\n        683,\n        659,\n        208,\n        198,\n        891,\n        383,\n        671,\n        488,\n        1170,\n        455,\n        1024,\n        1115,\n        269,\n        55,\n        214,\n        772,\n        615,\n        540,\n        1086,\n        640,\n        379,\n        745,\n        363,\n        655,\n        611,\n        934,\n        514,\n        756,\n        43,\n        124,\n        45,\n        1002,\n        1119,\n        1074,\n        722,\n        954,\n        680,\n        123,\n        272,\n        1098\n    ],\n    \"train\": [\n        464,\n        1062,\n        890,\n        22,\n        660,\n        427,\n        290,\n        167,\n        469,\n        1159,\n        341,\n        712,\n        560,\n        457,\n        77,\n        952,\n        639,\n        462,\n        1185,\n        1195,\n        654,\n        135,\n        83,\n        797,\n        730,\n        372,\n        650,\n        218,\n        225,\n        787,\n        285,\n        265,\n        343,\n        279,\n        452,\n        1028,\n        397,\n        
951,\n        948,\n        421,\n        855,\n        625,\n        158,\n        846,\n        1004,\n        1009,\n        753,\n        72,\n        466,\n        559,\n        914,\n        80,\n        591,\n        738,\n        142,\n        1120,\n        847,\n        408,\n        708,\n        973,\n        552,\n        236,\n        364,\n        1186,\n        207,\n        606,\n        1103,\n        338,\n        741,\n        622,\n        282,\n        178,\n        991,\n        263,\n        1054,\n        1112,\n        612,\n        586,\n        1072,\n        1193,\n        371,\n        1129,\n        480,\n        66,\n        1176,\n        781,\n        691,\n        507,\n        735,\n        853,\n        91,\n        575,\n        386,\n        192,\n        928,\n        67,\n        335,\n        785,\n        656,\n        585,\n        220,\n        789,\n        374,\n        347,\n        706,\n        231,\n        196,\n        188,\n        752,\n        681,\n        348,\n        324,\n        1177,\n        309,\n        151,\n        727,\n        929,\n        525,\n        688,\n        100,\n        95,\n        1141,\n        1124,\n        352,\n        519,\n        1090,\n        624,\n        1075,\n        623,\n        266,\n        351,\n        219,\n        723,\n        1133,\n        470,\n        1085,\n        75,\n        38,\n        1150,\n        510,\n        1031,\n        998,\n        800,\n        8,\n        641,\n        676,\n        370,\n        104,\n        288,\n        766,\n        415,\n        286,\n        292,\n        521,\n        411,\n        42,\n        1182,\n        443,\n        1110,\n        1092,\n        811,\n        873,\n        1073,\n        731,\n        842,\n        977,\n        350,\n        561,\n        788,\n        424,\n        1179,\n        312,\n        876,\n        906,\n        141,\n        321,\n        228,\n        564,\n        294,\n      
  791,\n        414,\n        922,\n        76,\n        118,\n        328,\n        138,\n        1166,\n        1156,\n        356,\n        628,\n        714,\n        841,\n        777,\n        1142,\n        799,\n        803,\n        103,\n        875,\n        844,\n        638,\n        1063,\n        1055,\n        726,\n        953,\n        1113,\n        313,\n        518,\n        315,\n        449,\n        935,\n        1160,\n        65,\n        881,\n        482,\n        6,\n        633,\n        784,\n        296,\n        498,\n        10,\n        1081,\n        337,\n        331,\n        865,\n        867,\n        434,\n        1169,\n        380,\n        605,\n        394,\n        154,\n        889,\n        884,\n        393,\n        412,\n        143,\n        217,\n        500,\n        587,\n        783,\n        815,\n        1052,\n        945,\n        535,\n        262,\n        961,\n        1007,\n        635,\n        12,\n        284,\n        61,\n        1089,\n        767,\n        672,\n        1076,\n        1197,\n        607,\n        717,\n        1117,\n        327,\n        999,\n        866,\n        1149,\n        0,\n        399,\n        1010,\n        345,\n        230,\n        994,\n        667,\n        121,\n        1036,\n        49,\n        1144,\n        1199,\n        1101,\n        992,\n        489,\n        451,\n        249,\n        533,\n        979,\n        927,\n        137,\n        666,\n        595,\n        496,\n        770,\n        959,\n        212,\n        534,\n        1099,\n        978,\n        1100,\n        197,\n        89,\n        697,\n        686,\n        171,\n        913,\n        1017,\n        477,\n        1146,\n        211,\n        558,\n        827,\n        1108,\n        201,\n        156,\n        619,\n        1039,\n        179,\n        915,\n        1088,\n        679,\n        437,\n        169,\n        657,\n        96,\n        27,\n        246,\n   
     502,\n        627,\n        684,\n        473,\n        647,\n        1005,\n        53,\n        739,\n        1046,\n        303,\n        485,\n        851,\n        29,\n        955,\n        373,\n        576,\n        699,\n        172,\n        15,\n        1131,\n        287,\n        930,\n        725,\n        589,\n        238,\n        653,\n        148,\n        454,\n        304,\n        1171,\n        511,\n        59,\n        318,\n        170,\n        84,\n        97,\n        857,\n        1162,\n        48,\n        947,\n        112,\n        481,\n        252,\n        703,\n        941,\n        1026,\n        617,\n        981,\n        837,\n        1021,\n        16,\n        1042,\n        162,\n        664,\n        838,\n        760,\n        517,\n        1059,\n        251,\n        1033,\n        420,\n        1154,\n        886,\n        71,\n        1198,\n        325,\n        1104,\n        257,\n        1015,\n        299,\n        830,\n        1084,\n        614,\n        848,\n        506,\n        721,\n        966,\n        849,\n        715,\n        596,\n        1018,\n        107,\n        888,\n        759,\n        418,\n        305,\n        669,\n        504,\n        1014,\n        608,\n        1027,\n        536,\n        368,\n        19,\n        233,\n        1122,\n        224,\n        796,\n        1020,\n        447,\n        234,\n        698,\n        11,\n        545,\n        956,\n        750,\n        106,\n        216,\n        907,\n        34,\n        1135,\n        242,\n        899,\n        200,\n        82,\n        818,\n        450,\n        1140,\n        476,\n        209,\n        21,\n        402,\n        1053,\n        161,\n        248,\n        113,\n        387,\n        357,\n        819,\n        543,\n        833,\n        32,\n        1095,\n        1048,\n        916,\n        1087,\n        40,\n        366,\n        17,\n        396,\n        578,\n        1006,\n     
   439,\n        1080,\n        68,\n        983,\n        355,\n        682,\n        445,\n        168,\n        165,\n        648,\n        9,\n        898,\n        743,\n        39,\n        259,\n        119,\n        316,\n        1061,\n        950,\n        25,\n        472,\n        574,\n        782,\n        405,\n        513,\n        971,\n        1056,\n        456,\n        858,\n        674,\n        145,\n        289,\n        57,\n        986,\n        786,\n        1044,\n        1180,\n        709,\n        416,\n        821,\n        108,\n        1019,\n        940,\n        693,\n        828,\n        744,\n        646,\n        805,\n        177,\n        503,\n        239,\n        905,\n        2,\n        859,\n        526,\n        824,\n        764,\n        920,\n        317,\n        871,\n        562,\n        417,\n        1049,\n        631,\n        300,\n        378,\n        610,\n        206,\n        588,\n        174,\n        1043,\n        407,\n        149,\n        912,\n        621,\n        1008,\n        193,\n        189,\n        1139,\n        673,\n        565,\n        817,\n        155,\n        114,\n        816,\n        361,\n        806,\n        458,\n        413,\n        340,\n        362,\n        700,\n        253,\n        794,\n        516,\n        685,\n        845,\n        705,\n        832,\n        1187,\n        900,\n        1078,\n        754,\n        367,\n        401,\n        976,\n        755,\n        185,\n        969,\n        1041,\n        180,\n        1105,\n        768,\n        468,\n        652,\n        944,\n        854,\n        974,\n        629,\n        144,\n        573,\n        110,\n        1130,\n        742,\n        1060,\n        775,\n        153,\n        186,\n        724,\n        301,\n        60,\n        924,\n        520,\n        505,\n        936,\n        538,\n        939,\n        579,\n        1037,\n        1058,\n        860,\n        1050,\n       
 461,\n        377,\n        807,\n        1064,\n        99,\n        256,\n        1057,\n        448,\n        115,\n        442,\n        78,\n        281,\n        917,\n        885,\n        557,\n        932,\n        400,\n        872,\n        701,\n        215,\n        747,\n        497,\n        388,\n        1034,\n        339,\n        342,\n        30,\n        919,\n        344,\n        670,\n        463,\n        365,\n        1003,\n        1016,\n        270,\n        780,\n        120,\n        1173,\n        93,\n        894,\n        187,\n        320,\n        663,\n        173,\n        261,\n        762,\n        636,\n        191,\n        140,\n        264,\n        194,\n        746,\n        260,\n        769,\n        880,\n        728,\n        235,\n        478,\n        964,\n        903,\n        74,\n        988,\n        809,\n        1025,\n        1143,\n        594,\n        1188,\n        1094,\n        736,\n        491,\n        758,\n        147,\n        879,\n        329,\n        1196,\n        861,\n        1138,\n        176,\n        205,\n        391,\n        490,\n        1164,\n        1083,\n        314,\n        222,\n        1132,\n        692,\n        598,\n        609,\n        938,\n        687,\n        302,\n        550,\n        237,\n        826,\n        887,\n        599,\n        150,\n        957,\n        273,\n        139,\n        190,\n        863,\n        453,\n        720,\n        1151,\n        509,\n        98,\n        166,\n        152,\n        410,\n        658,\n        1000,\n        495,\n        702,\n        529,\n        690,\n        593,\n        582,\n        802,\n        425,\n        164,\n        422,\n        240,\n        512,\n        776,\n        85,\n        823,\n        569,\n        358,\n        406,\n        613,\n        376,\n        931,\n        403,\n        716,\n        630,\n        465,\n        255,\n        661,\n        1163,\n        1030,\n       
 79,\n        184,\n        761,\n        332,\n        563,\n        962,\n        566,\n        102,\n        567,\n        109,\n        943,\n        732,\n        541,\n        523,\n        779,\n        37,\n        1123,\n        306,\n        592,\n        291,\n        713,\n        493,\n        1118,\n        117,\n        1192,\n        804,\n        933,\n        24,\n        183,\n        384,\n        675,\n        101,\n        792,\n        94,\n        896,\n        1070,\n        864,\n        241,\n        146,\n        892,\n        5,\n        995,\n        1127,\n        105,\n        544,\n        577,\n        1174,\n        751,\n        111,\n        1040,\n        385,\n        542,\n        1178,\n        749,\n        982,\n        707,\n        1011,\n        1069,\n        268,\n        1134,\n        1045,\n        333,\n        28,\n        133,\n        436,\n        229,\n        435,\n        868,\n        604,\n        711,\n        276,\n        548,\n        223,\n        996,\n        18,\n        909,\n        428,\n        1191,\n        430,\n        869,\n        719,\n        778,\n        508,\n        1102,\n        381,\n        937,\n        441,\n        993,\n        44,\n        820,\n        651,\n        1082,\n        1032,\n        926,\n        665,\n        618,\n        471,\n        382,\n        1068,\n        182,\n        580,\n        81,\n        814,\n        389,\n        740,\n        733,\n        3,\n        568,\n        130,\n        126,\n        438,\n        336,\n        195,\n        271,\n        369,\n        311,\n        600,\n        398,\n        662,\n        395,\n        359,\n        1116,\n        322,\n        160,\n        323,\n        501,\n        1066,\n        527,\n        88,\n        729,\n        1126,\n        985,\n        494,\n        70,\n        902,\n        127,\n        47,\n        1136,\n        1168,\n        56,\n        877,\n        1,\n        839,\n 
       795,\n        330,\n        484,\n        52,\n        433,\n        532,\n        1106,\n        632,\n        275,\n        1137,\n        1012,\n        774,\n        51,\n        409,\n        1091,\n        584,\n        643,\n        499,\n        7,\n        843,\n        1175,\n        1022,\n        1165,\n        280,\n        810,\n        245,\n        958,\n        967,\n        836,\n        908,\n        547,\n        125,\n        202,\n        226,\n        360,\n        62,\n        801,\n        831,\n        997,\n        553,\n        989,\n        258,\n        128,\n        1183,\n        116,\n        626,\n        1111,\n        54,\n        1001,\n        549,\n        537,\n        20,\n        980,\n        244,\n        307,\n        515,\n        1023,\n        1077,\n        984,\n        492,\n        13,\n        50,\n        590,\n        440,\n        921,\n        904,\n        918,\n        203,\n        946,\n        524\n    ]\n}"
  },
  {
    "path": "datainfo/fire/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        248,\n        491,\n        35,\n        873,\n        603,\n        402,\n        517,\n        866,\n        511,\n        903,\n        601,\n        290,\n        310,\n        194,\n        686,\n        849,\n        519,\n        951,\n        512,\n        728,\n        968,\n        917,\n        340,\n        760,\n        123,\n        303,\n        880,\n        741,\n        644,\n        190,\n        102,\n        657,\n        569,\n        857,\n        426,\n        960,\n        296,\n        470,\n        989,\n        224,\n        693,\n        236,\n        353,\n        238,\n        566,\n        990,\n        448,\n        994,\n        227,\n        540,\n        980,\n        743,\n        432,\n        221,\n        702,\n        390,\n        902,\n        9,\n        554,\n        665,\n        26,\n        22,\n        31,\n        325,\n        923,\n        104,\n        605,\n        234,\n        995,\n        738,\n        272,\n        456,\n        712,\n        2,\n        785,\n        780,\n        622,\n        443,\n        399,\n        855,\n        914,\n        29,\n        499,\n        96,\n        214,\n        807,\n        388,\n        667,\n        483,\n        460,\n        779,\n        507,\n        120,\n        261,\n        64,\n        782,\n        821,\n        867,\n        582,\n        137\n    ],\n    \"val\": [\n        992,\n        564,\n        93,\n        185,\n        598,\n        563,\n        181,\n        650,\n        235,\n        28,\n        614,\n        469,\n        339,\n        627,\n        805,\n        638,\n        553,\n        551,\n        1,\n        354,\n        895,\n        365,\n        496,\n        423,\n        516,\n        856,\n        567,\n        583,\n        373,\n        492,\n        57,\n        436,\n        210,\n        574,\n        796,\n        531,\n        132,\n        828,\n        524,\n        
758,\n        802,\n        392,\n        5,\n        746,\n        623,\n        854,\n        675,\n        275,\n        936,\n        361,\n        591,\n        352,\n        526,\n        414,\n        237,\n        891,\n        552,\n        204,\n        789,\n        12,\n        232,\n        514,\n        172,\n        174,\n        662,\n        403,\n        592,\n        607,\n        629,\n        720,\n        315,\n        44,\n        480,\n        30,\n        750,\n        501,\n        379,\n        904,\n        860,\n        533,\n        167,\n        797,\n        110,\n        520,\n        679,\n        449,\n        88,\n        383,\n        755,\n        690,\n        794,\n        719,\n        561,\n        375,\n        177,\n        680,\n        424,\n        413,\n        816,\n        761\n    ],\n    \"train\": [\n        852,\n        836,\n        280,\n        575,\n        208,\n        609,\n        931,\n        725,\n        47,\n        984,\n        337,\n        581,\n        155,\n        692,\n        862,\n        790,\n        572,\n        241,\n        991,\n        788,\n        913,\n        498,\n        107,\n        767,\n        839,\n        930,\n        300,\n        386,\n        490,\n        920,\n        49,\n        216,\n        886,\n        73,\n        125,\n        565,\n        262,\n        717,\n        596,\n        669,\n        309,\n        803,\n        876,\n        848,\n        580,\n        864,\n        447,\n        484,\n        418,\n        147,\n        556,\n        323,\n        37,\n        586,\n        438,\n        409,\n        52,\n        599,\n        687,\n        878,\n        716,\n        243,\n        131,\n        562,\n        219,\n        95,\n        808,\n        82,\n        696,\n        306,\n        901,\n        144,\n        998,\n        415,\n        548,\n        4,\n        704,\n        707,\n        143,\n        269,\n        879,\n        
479,\n        284,\n        888,\n        950,\n        294,\n        949,\n        13,\n        705,\n        515,\n        751,\n        356,\n        121,\n        621,\n        872,\n        894,\n        982,\n        472,\n        477,\n        863,\n        954,\n        295,\n        729,\n        378,\n        641,\n        753,\n        58,\n        798,\n        371,\n        271,\n        973,\n        169,\n        239,\n        207,\n        358,\n        955,\n        600,\n        509,\n        434,\n        585,\n        34,\n        634,\n        433,\n        311,\n        701,\n        475,\n        331,\n        473,\n        615,\n        7,\n        421,\n        266,\n        56,\n        384,\n        183,\n        391,\n        870,\n        200,\n        100,\n        59,\n        494,\n        893,\n        146,\n        685,\n        396,\n        892,\n        316,\n        176,\n        180,\n        834,\n        766,\n        19,\n        806,\n        68,\n        795,\n        453,\n        231,\n        616,\n        53,\n        246,\n        912,\n        709,\n        814,\n        23,\n        336,\n        570,\n        811,\n        481,\n        113,\n        168,\n        762,\n        362,\n        697,\n        730,\n        543,\n        25,\n        289,\n        197,\n        11,\n        987,\n        628,\n        233,\n        734,\n        655,\n        242,\n        276,\n        77,\n        915,\n        698,\n        444,\n        594,\n        67,\n        89,\n        393,\n        154,\n        673,\n        417,\n        299,\n        726,\n        199,\n        510,\n        39,\n        206,\n        530,\n        942,\n        446,\n        482,\n        525,\n        768,\n        635,\n        267,\n        370,\n        979,\n        678,\n        837,\n        966,\n        559,\n        842,\n        568,\n        293,\n        877,\n        109,\n        148,\n        99,\n        381,\n        
345,\n        134,\n        544,\n        244,\n        135,\n        590,\n        158,\n        459,\n        368,\n        838,\n        883,\n        366,\n        405,\n        945,\n        425,\n        926,\n        910,\n        865,\n        940,\n        50,\n        343,\n        186,\n        815,\n        784,\n        529,\n        184,\n        642,\n        478,\n        338,\n        254,\n        817,\n        285,\n        677,\n        777,\n        829,\n        550,\n        874,\n        824,\n        981,\n        179,\n        749,\n        226,\n        927,\n        0,\n        633,\n        357,\n        257,\n        223,\n        493,\n        537,\n        410,\n        695,\n        69,\n        971,\n        905,\n        6,\n        660,\n        666,\n        55,\n        700,\n        380,\n        350,\n        631,\n        62,\n        215,\n        534,\n        182,\n        487,\n        653,\n        742,\n        360,\n        875,\n        139,\n        708,\n        451,\n        75,\n        145,\n        419,\n        699,\n        706,\n        124,\n        15,\n        882,\n        474,\n        630,\n        412,\n        97,\n        972,\n        394,\n        497,\n        127,\n        142,\n        151,\n        868,\n        209,\n        613,\n        307,\n        454,\n        61,\n        108,\n        527,\n        318,\n        495,\n        825,\n        166,\n        377,\n        422,\n        906,\n        195,\n        963,\n        947,\n        654,\n        869,\n        781,\n        539,\n        652,\n        51,\n        321,\n        130,\n        953,\n        268,\n        733,\n        278,\n        612,\n        476,\n        196,\n        178,\n        326,\n        201,\n        640,\n        401,\n        959,\n        291,\n        757,\n        648,\n        452,\n        79,\n        455,\n        193,\n        884,\n        658,\n        535,\n        63,\n        43,\n        
304,\n        853,\n        840,\n        359,\n        84,\n        928,\n        282,\n        810,\n        916,\n        152,\n        159,\n        964,\n        385,\n        81,\n        188,\n        676,\n        643,\n        745,\n        770,\n        786,\n        437,\n        140,\n        334,\n        372,\n        312,\n        286,\n        776,\n        211,\n        737,\n        727,\n        595,\n        407,\n        253,\n        775,\n        122,\n        115,\n        374,\n        571,\n        898,\n        450,\n        332,\n        844,\n        588,\n        270,\n        90,\n        3,\n        440,\n        119,\n        45,\n        715,\n        944,\n        671,\n        674,\n        649,\n        141,\n        819,\n        793,\n        429,\n        900,\n        683,\n        70,\n        935,\n        813,\n        818,\n        791,\n        956,\n        292,\n        213,\n        330,\n        887,\n        602,\n        899,\n        835,\n        87,\n        651,\n        428,\n        202,\n        841,\n        542,\n        467,\n        435,\n        938,\n        277,\n        624,\n        281,\n        661,\n        896,\n        731,\n        843,\n        32,\n        398,\n        129,\n        126,\n        341,\n        441,\n        820,\n        764,\n        80,\n        787,\n        952,\n        153,\n        506,\n        489,\n        941,\n        382,\n        430,\n        606,\n        464,\n        344,\n        851,\n        763,\n        462,\n        502,\n        161,\n        541,\n        16,\n        549,\n        997,\n        774,\n        503,\n        925,\n        457,\n        943,\n        625,\n        308,\n        847,\n        427,\n        263,\n        363,\n        54,\n        156,\n        937,\n        688,\n        420,\n        523,\n        754,\n        397,\n        546,\n        885,\n        722,\n        486,\n        260,\n        619,\n        538,\n    
    513,\n        532,\n        157,\n        558,\n        505,\n        367,\n        946,\n        573,\n        934,\n        76,\n        714,\n        647,\n        812,\n        723,\n        249,\n        20,\n        744,\n        250,\n        94,\n        103,\n        342,\n        251,\n        911,\n        632,\n        191,\n        670,\n        42,\n        978,\n        682,\n        859,\n        986,\n        92,\n        458,\n        91,\n        932,\n        735,\n        369,\n        252,\n        203,\n        314,\n        610,\n        957,\n        187,\n        265,\n        364,\n        871,\n        636,\n        220,\n        256,\n        114,\n        466,\n        324,\n        65,\n        993,\n        408,\n        320,\n        400,\n        988,\n        327,\n        908,\n        10,\n        27,\n        597,\n        962,\n        212,\n        929,\n        322,\n        488,\n        756,\n        617,\n        771,\n        555,\n        752,\n        445,\n        105,\n        710,\n        247,\n        974,\n        732,\n        118,\n        165,\n        922,\n        245,\n        759,\n        996,\n        608,\n        822,\n        801,\n        792,\n        858,\n        611,\n        46,\n        881,\n        283,\n        468,\n        975,\n        659,\n        577,\n        302,\n        827,\n        273,\n        83,\n        579,\n        229,\n        804,\n        584,\n        713,\n        739,\n        909,\n        117,\n        349,\n        718,\n        150,\n        389,\n        74,\n        416,\n        40,\n        724,\n        684,\n        800,\n        593,\n        668,\n        500,\n        618,\n        656,\n        298,\n        765,\n        348,\n        346,\n        376,\n        778,\n        740,\n        809,\n        921,\n        274,\n        958,\n        897,\n        170,\n        173,\n        136,\n        86,\n        41,\n        66,\n        240,\n    
    545,\n        967,\n        547,\n        783,\n        560,\n        985,\n        98,\n        969,\n        218,\n        439,\n        347,\n        138,\n        576,\n        335,\n        939,\n        160,\n        748,\n        288,\n        411,\n        626,\n        333,\n        889,\n        907,\n        823,\n        924,\n        977,\n        681,\n        106,\n        504,\n        198,\n        965,\n        976,\n        587,\n        831,\n        101,\n        355,\n        205,\n        387,\n        703,\n        521,\n        637,\n        175,\n        471,\n        826,\n        222,\n        604,\n        38,\n        832,\n        8,\n        133,\n        258,\n        578,\n        933,\n        162,\n        769,\n        317,\n        78,\n        833,\n        313,\n        48,\n        217,\n        128,\n        305,\n        60,\n        919,\n        646,\n        845,\n        328,\n        589,\n        691,\n        404,\n        961,\n        664,\n        536,\n        228,\n        461,\n        528,\n        711,\n        645,\n        225,\n        557,\n        830,\n        694,\n        518,\n        721,\n        970,\n        164,\n        736,\n        36,\n        149,\n        406,\n        18,\n        230,\n        21,\n        442,\n        620,\n        983,\n        522,\n        747,\n        259,\n        111,\n        264,\n        192,\n        431,\n        351,\n        395,\n        319,\n        24,\n        116,\n        485,\n        508,\n        329,\n        890,\n        465,\n        301,\n        918,\n        663,\n        279,\n        672,\n        861,\n        948,\n        799,\n        163,\n        171,\n        71,\n        297,\n        850,\n        189,\n        639,\n        112,\n        846,\n        255,\n        287,\n        773,\n        772,\n        14,\n        463,\n        17,\n        85,\n        72,\n        689,\n        33\n    ]\n}"
  },
  {
    "path": "datainfo/fire/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        465,\n        677,\n        921,\n        956,\n        851,\n        527,\n        512,\n        510,\n        285,\n        501,\n        255,\n        543,\n        670,\n        472,\n        756,\n        732,\n        409,\n        772,\n        165,\n        930,\n        879,\n        370,\n        362,\n        607,\n        808,\n        966,\n        781,\n        537,\n        752,\n        424,\n        815,\n        456,\n        915,\n        186,\n        947,\n        690,\n        526,\n        368,\n        938,\n        522,\n        139,\n        177,\n        332,\n        180,\n        24,\n        236,\n        241,\n        181,\n        573,\n        168,\n        538,\n        905,\n        913,\n        433,\n        389,\n        929,\n        326,\n        954,\n        476,\n        372,\n        28,\n        891,\n        979,\n        922,\n        274,\n        514,\n        455,\n        958,\n        557,\n        380,\n        521,\n        880,\n        740,\n        822,\n        402,\n        653,\n        441,\n        162,\n        697,\n        595,\n        36,\n        621,\n        217,\n        620,\n        257,\n        315,\n        874,\n        685,\n        828,\n        753,\n        173,\n        855,\n        369,\n        86,\n        93,\n        57,\n        869,\n        970,\n        883,\n        978\n    ],\n    \"val\": [\n        4,\n        37,\n        155,\n        253,\n        44,\n        603,\n        394,\n        1,\n        708,\n        535,\n        188,\n        927,\n        160,\n        130,\n        261,\n        382,\n        21,\n        746,\n        41,\n        25,\n        69,\n        117,\n        84,\n        943,\n        688,\n        909,\n        176,\n        936,\n        371,\n        58,\n        32,\n        777,\n        734,\n        952,\n        61,\n        215,\n        250,\n        272,\n        939,\n        534,\n 
       916,\n        849,\n        698,\n        232,\n        605,\n        279,\n        50,\n        668,\n        588,\n        60,\n        108,\n        762,\n        195,\n        887,\n        8,\n        895,\n        349,\n        840,\n        803,\n        77,\n        638,\n        700,\n        375,\n        524,\n        500,\n        212,\n        748,\n        319,\n        416,\n        602,\n        630,\n        667,\n        519,\n        530,\n        575,\n        516,\n        903,\n        723,\n        818,\n        310,\n        316,\n        491,\n        791,\n        963,\n        631,\n        170,\n        990,\n        716,\n        834,\n        941,\n        227,\n        674,\n        498,\n        467,\n        741,\n        570,\n        743,\n        581,\n        359,\n        912\n    ],\n    \"train\": [\n        290,\n        214,\n        112,\n        544,\n        695,\n        541,\n        792,\n        671,\n        658,\n        339,\n        298,\n        644,\n        820,\n        496,\n        92,\n        289,\n        206,\n        211,\n        997,\n        218,\n        856,\n        847,\n        988,\n        62,\n        844,\n        22,\n        559,\n        507,\n        266,\n        643,\n        793,\n        870,\n        866,\n        542,\n        432,\n        300,\n        471,\n        647,\n        99,\n        627,\n        64,\n        554,\n        379,\n        245,\n        152,\n        629,\n        276,\n        127,\n        470,\n        251,\n        810,\n        482,\n        308,\n        579,\n        560,\n        307,\n        928,\n        505,\n        683,\n        153,\n        711,\n        750,\n        357,\n        539,\n        66,\n        710,\n        701,\n        341,\n        513,\n        457,\n        146,\n        492,\n        111,\n        485,\n        348,\n        529,\n        574,\n        959,\n        469,\n        322,\n        759,\n        
356,\n        552,\n        830,\n        43,\n        299,\n        120,\n        91,\n        398,\n        649,\n        948,\n        648,\n        968,\n        365,\n        645,\n        463,\n        443,\n        440,\n        682,\n        373,\n        865,\n        737,\n        733,\n        873,\n        662,\n        924,\n        216,\n        897,\n        651,\n        158,\n        454,\n        805,\n        839,\n        42,\n        660,\n        282,\n        2,\n        417,\n        745,\n        993,\n        894,\n        387,\n        229,\n        736,\n        821,\n        189,\n        198,\n        556,\n        106,\n        105,\n        899,\n        70,\n        517,\n        950,\n        466,\n        515,\n        867,\n        508,\n        576,\n        151,\n        807,\n        614,\n        9,\n        582,\n        680,\n        779,\n        228,\n        864,\n        249,\n        580,\n        831,\n        405,\n        10,\n        563,\n        964,\n        49,\n        488,\n        878,\n        687,\n        448,\n        871,\n        301,\n        140,\n        626,\n        434,\n        747,\n        378,\n        311,\n        284,\n        312,\n        377,\n        361,\n        911,\n        499,\n        624,\n        525,\n        80,\n        101,\n        654,\n        283,\n        715,\n        390,\n        87,\n        549,\n        295,\n        306,\n        325,\n        547,\n        461,\n        600,\n        26,\n        487,\n        817,\n        980,\n        785,\n        744,\n        795,\n        100,\n        854,\n        439,\n        532,\n        616,\n        969,\n        427,\n        798,\n        13,\n        901,\n        264,\n        410,\n        286,\n        63,\n        29,\n        945,\n        33,\n        892,\n        797,\n        355,\n        726,\n        618,\n        135,\n        742,\n        350,\n        133,\n        208,\n        273,\n        
806,\n        294,\n        51,\n        243,\n        15,\n        138,\n        309,\n        178,\n        890,\n        7,\n        991,\n        35,\n        123,\n        661,\n        659,\n        328,\n        545,\n        219,\n        12,\n        376,\n        166,\n        985,\n        882,\n        269,\n        967,\n        406,\n        354,\n        6,\n        344,\n        56,\n        584,\n        555,\n        129,\n        331,\n        612,\n        156,\n        436,\n        438,\n        758,\n        729,\n        848,\n        340,\n        925,\n        55,\n        587,\n        231,\n        951,\n        197,\n        210,\n        601,\n        842,\n        205,\n        802,\n        450,\n        813,\n        782,\n        278,\n        976,\n        835,\n        242,\n        611,\n        479,\n        116,\n        351,\n        569,\n        59,\n        415,\n        974,\n        511,\n        692,\n        889,\n        794,\n        735,\n        90,\n        73,\n        565,\n        190,\n        172,\n        23,\n        122,\n        126,\n        923,\n        920,\n        767,\n        646,\n        423,\n        934,\n        919,\n        128,\n        169,\n        787,\n        597,\n        48,\n        914,\n        665,\n        693,\n        594,\n        858,\n        900,\n        145,\n        975,\n        199,\n        714,\n        385,\n        363,\n        277,\n        717,\n        223,\n        333,\n        931,\n        183,\n        933,\n        179,\n        53,\n        850,\n        65,\n        425,\n        568,\n        780,\n        622,\n        244,\n        225,\n        613,\n        446,\n        709,\n        452,\n        202,\n        421,\n        809,\n        800,\n        343,\n        955,\n        318,\n        836,\n        396,\n        770,\n        360,\n        827,\n        712,\n        804,\n        755,\n        684,\n        304,\n        623,\n        
778,\n        175,\n        761,\n        193,\n        431,\n        281,\n        137,\n        419,\n        490,\n        71,\n        608,\n        94,\n        191,\n        347,\n        97,\n        566,\n        673,\n        845,\n        81,\n        167,\n        837,\n        110,\n        327,\n        397,\n        786,\n        453,\n        944,\n        149,\n        460,\n        572,\n        85,\n        679,\n        713,\n        504,\n        447,\n        182,\n        163,\n        109,\n        704,\n        194,\n        775,\n        637,\n        699,\n        994,\n        789,\n        142,\n        267,\n        641,\n        287,\n        234,\n        119,\n        705,\n        918,\n        45,\n        321,\n        258,\n        408,\n        76,\n        271,\n        305,\n        200,\n        672,\n        324,\n        632,\n        47,\n        430,\n        604,\n        201,\n        329,\n        334,\n        578,\n        407,\n        330,\n        884,\n        801,\n        606,\n        550,\n        132,\n        39,\n        751,\n        596,\n        696,\n        635,\n        247,\n        591,\n        888,\n        681,\n        776,\n        346,\n        445,\n        707,\n        226,\n        288,\n        957,\n        823,\n        615,\n        838,\n        187,\n        0,\n        239,\n        719,\n        617,\n        739,\n        79,\n        593,\n        486,\n        640,\n        875,\n        824,\n        546,\n        27,\n        819,\n        694,\n        442,\n        592,\n        403,\n        540,\n        506,\n        766,\n        207,\n        724,\n        480,\n        384,\n        388,\n        590,\n        518,\n        906,\n        314,\n        40,\n        14,\n        610,\n        983,\n        583,\n        136,\n        235,\n        171,\n        567,\n        881,\n        473,\n        248,\n        302,\n        386,\n        67,\n        656,\n        
3,\n        358,\n        74,\n        209,\n        478,\n        940,\n        551,\n        493,\n        689,\n        853,\n        141,\n        908,\n        796,\n        577,\n        204,\n        586,\n        706,\n        220,\n        395,\n        221,\n        910,\n        17,\n        678,\n        946,\n        562,\n        896,\n        977,\n        320,\n        11,\n        391,\n        412,\n        192,\n        774,\n        83,\n        338,\n        876,\n        483,\n        666,\n        816,\n        414,\n        280,\n        494,\n        54,\n        937,\n        763,\n        184,\n        392,\n        30,\n        451,\n        477,\n        148,\n        367,\n        413,\n        240,\n        898,\n        161,\n        989,\n        263,\n        935,\n        254,\n        528,\n        336,\n        764,\n        720,\n        553,\n        159,\n        114,\n        293,\n        536,\n        825,\n        449,\n        337,\n        962,\n        399,\n        609,\n        131,\n        38,\n        88,\n        68,\n        426,\n        150,\n        558,\n        942,\n        364,\n        72,\n        893,\n        589,\n        599,\n        992,\n        437,\n        561,\n        783,\n        790,\n        811,\n        771,\n        634,\n        303,\n        503,\n        754,\n        489,\n        718,\n        495,\n        275,\n        383,\n        144,\n        468,\n        366,\n        297,\n        509,\n        857,\n        481,\n        721,\n        664,\n        691,\n        400,\n        342,\n        728,\n        125,\n        633,\n        260,\n        124,\n        404,\n        353,\n        749,\n        655,\n        401,\n        814,\n        585,\n        118,\n        196,\n        502,\n        224,\n        78,\n        663,\n        669,\n        768,\n        291,\n        5,\n        730,\n        628,\n        868,\n        213,\n        474,\n        497,\n       
 652,\n        420,\n        833,\n        462,\n        16,\n        203,\n        313,\n        702,\n        103,\n        422,\n        675,\n        788,\n        381,\n        296,\n        861,\n        625,\n        121,\n        531,\n        19,\n        374,\n        996,\n        96,\n        185,\n        841,\n        926,\n        428,\n        52,\n        727,\n        998,\n        902,\n        523,\n        435,\n        444,\n        852,\n        981,\n        953,\n        657,\n        548,\n        317,\n        971,\n        475,\n        986,\n        533,\n        877,\n        222,\n        107,\n        738,\n        932,\n        20,\n        102,\n        769,\n        246,\n        89,\n        862,\n        113,\n        164,\n        418,\n        393,\n        961,\n        154,\n        799,\n        949,\n        907,\n        832,\n        860,\n        262,\n        826,\n        859,\n        639,\n        233,\n        843,\n        82,\n        256,\n        75,\n        965,\n        731,\n        238,\n        252,\n        829,\n        725,\n        520,\n        237,\n        650,\n        464,\n        98,\n        174,\n        917,\n        134,\n        34,\n        259,\n        987,\n        686,\n        143,\n        571,\n        886,\n        18,\n        846,\n        429,\n        982,\n        268,\n        265,\n        885,\n        147,\n        335,\n        904,\n        960,\n        973,\n        598,\n        872,\n        812,\n        458,\n        972,\n        104,\n        323,\n        703,\n        676,\n        95,\n        230,\n        484,\n        157,\n        722,\n        636,\n        411,\n        773,\n        270,\n        46,\n        757,\n        619,\n        784,\n        564,\n        459,\n        984,\n        31,\n        863,\n        345,\n        292,\n        115,\n        765,\n        760,\n        642,\n        995,\n        352\n    ]\n}"
  },
  {
    "path": "datainfo/fire/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        712,\n        959,\n        987,\n        286,\n        876,\n        29,\n        698,\n        344,\n        968,\n        598,\n        417,\n        599,\n        546,\n        359,\n        587,\n        395,\n        853,\n        519,\n        431,\n        952,\n        875,\n        641,\n        797,\n        446,\n        688,\n        264,\n        222,\n        506,\n        139,\n        36,\n        99,\n        374,\n        943,\n        137,\n        455,\n        590,\n        820,\n        745,\n        404,\n        437,\n        807,\n        731,\n        396,\n        899,\n        942,\n        736,\n        609,\n        484,\n        275,\n        886,\n        843,\n        31,\n        798,\n        308,\n        43,\n        605,\n        776,\n        163,\n        65,\n        795,\n        687,\n        15,\n        759,\n        399,\n        535,\n        948,\n        888,\n        155,\n        650,\n        237,\n        154,\n        881,\n        654,\n        406,\n        487,\n        562,\n        856,\n        553,\n        481,\n        734,\n        196,\n        239,\n        564,\n        265,\n        480,\n        857,\n        930,\n        13,\n        620,\n        67,\n        594,\n        640,\n        485,\n        618,\n        937,\n        378,\n        133,\n        557,\n        606,\n        243\n    ],\n    \"val\": [\n        603,\n        630,\n        830,\n        870,\n        322,\n        593,\n        866,\n        11,\n        835,\n        561,\n        310,\n        533,\n        924,\n        447,\n        918,\n        998,\n        732,\n        737,\n        649,\n        441,\n        277,\n        916,\n        635,\n        105,\n        572,\n        697,\n        945,\n        659,\n        914,\n        532,\n        471,\n        385,\n        859,\n        141,\n        368,\n        321,\n        347,\n        953,\n        706,\n     
   159,\n        269,\n        625,\n        298,\n        909,\n        202,\n        32,\n        548,\n        614,\n        110,\n        78,\n        7,\n        317,\n        928,\n        241,\n        517,\n        285,\n        981,\n        338,\n        600,\n        735,\n        386,\n        46,\n        779,\n        629,\n        619,\n        45,\n        121,\n        425,\n        787,\n        938,\n        300,\n        20,\n        969,\n        420,\n        68,\n        819,\n        352,\n        90,\n        495,\n        971,\n        874,\n        493,\n        64,\n        127,\n        291,\n        273,\n        913,\n        851,\n        648,\n        216,\n        671,\n        730,\n        106,\n        582,\n        585,\n        554,\n        334,\n        970,\n        715,\n        167\n    ],\n    \"train\": [\n        316,\n        973,\n        262,\n        342,\n        904,\n        898,\n        566,\n        511,\n        871,\n        143,\n        176,\n        910,\n        644,\n        290,\n        927,\n        988,\n        624,\n        185,\n        370,\n        720,\n        156,\n        346,\n        388,\n        313,\n        616,\n        148,\n        891,\n        371,\n        993,\n        567,\n        832,\n        249,\n        272,\n        415,\n        529,\n        592,\n        472,\n        513,\n        749,\n        261,\n        664,\n        844,\n        848,\n        538,\n        236,\n        926,\n        847,\n        25,\n        777,\n        89,\n        147,\n        460,\n        873,\n        401,\n        219,\n        209,\n        115,\n        288,\n        684,\n        578,\n        826,\n        457,\n        174,\n        267,\n        576,\n        152,\n        432,\n        966,\n        362,\n        686,\n        520,\n        142,\n        746,\n        668,\n        145,\n        470,\n        98,\n        834,\n        709,\n        449,\n        908,\n     
   884,\n        192,\n        312,\n        889,\n        718,\n        71,\n        950,\n        690,\n        38,\n        974,\n        210,\n        704,\n        218,\n        421,\n        855,\n        380,\n        658,\n        536,\n        42,\n        303,\n        541,\n        762,\n        456,\n        207,\n        450,\n        402,\n        766,\n        579,\n        259,\n        551,\n        413,\n        412,\n        767,\n        982,\n        309,\n        198,\n        583,\n        166,\n        550,\n        915,\n        2,\n        922,\n        773,\n        393,\n        846,\n        179,\n        639,\n        217,\n        785,\n        814,\n        377,\n        713,\n        983,\n        500,\n        256,\n        390,\n        40,\n        479,\n        365,\n        355,\n        699,\n        733,\n        66,\n        476,\n        238,\n        463,\n        758,\n        491,\n        464,\n        114,\n        337,\n        328,\n        540,\n        925,\n        655,\n        893,\n        112,\n        738,\n        228,\n        466,\n        552,\n        279,\n        771,\n        53,\n        108,\n        199,\n        17,\n        865,\n        225,\n        100,\n        41,\n        701,\n        201,\n        742,\n        10,\n        907,\n        80,\n        911,\n        788,\n        178,\n        407,\n        315,\n        16,\n        627,\n        266,\n        190,\n        301,\n        8,\n        242,\n        661,\n        503,\n        235,\n        177,\n        443,\n        84,\n        504,\n        394,\n        790,\n        944,\n        253,\n        19,\n        622,\n        233,\n        521,\n        507,\n        6,\n        12,\n        809,\n        783,\n        646,\n        119,\n        725,\n        434,\n        168,\n        319,\n        144,\n        565,\n        113,\n        165,\n        366,\n        292,\n        9,\n        663,\n        251,\n        
601,\n        350,\n        757,\n        414,\n        392,\n        756,\n        636,\n        645,\n        59,\n        462,\n        193,\n        197,\n        693,\n        0,\n        281,\n        186,\n        48,\n        314,\n        638,\n        405,\n        287,\n        220,\n        424,\n        555,\n        282,\n        794,\n        739,\n        364,\n        901,\n        162,\n        158,\n        900,\n        379,\n        208,\n        689,\n        985,\n        840,\n        324,\n        505,\n        103,\n        325,\n        716,\n        743,\n        672,\n        74,\n        957,\n        302,\n        213,\n        96,\n        545,\n        571,\n        990,\n        478,\n        442,\n        77,\n        837,\n        991,\n        526,\n        813,\n        633,\n        595,\n        206,\n        170,\n        770,\n        468,\n        559,\n        710,\n        518,\n        628,\n        861,\n        474,\n        760,\n        979,\n        138,\n        452,\n        863,\n        558,\n        363,\n        294,\n        490,\n        283,\n        92,\n        356,\n        408,\n        810,\n        824,\n        422,\n        234,\n        418,\n        153,\n        284,\n        215,\n        774,\n        769,\n        72,\n        902,\n        55,\n        63,\n        681,\n        768,\n        91,\n        967,\n        611,\n        391,\n        975,\n        150,\n        30,\n        180,\n        260,\n        252,\n        58,\n        839,\n        274,\n        289,\n        858,\n        890,\n        754,\n        805,\n        339,\n        729,\n        727,\n        188,\n        939,\n        748,\n        919,\n        486,\n        653,\n        764,\n        224,\n        57,\n        221,\n        793,\n        792,\n        140,\n        796,\n        956,\n        278,\n        896,\n        200,\n        574,\n        453,\n        667,\n        761,\n        39,\n        
248,\n        194,\n        573,\n        169,\n        171,\n        960,\n        989,\n        172,\n        935,\n        842,\n        613,\n        231,\n        670,\n        933,\n        149,\n        135,\n        621,\n        60,\n        589,\n        373,\n        675,\n        965,\n        93,\n        610,\n        416,\n        86,\n        516,\n        575,\n        821,\n        958,\n        318,\n        509,\n        992,\n        931,\n        410,\n        97,\n        666,\n        596,\n        451,\n        674,\n        510,\n        397,\n        367,\n        331,\n        120,\n        660,\n        702,\n        23,\n        375,\n        400,\n        696,\n        862,\n        679,\n        246,\n        14,\n        719,\n        895,\n        872,\n        427,\n        703,\n        980,\n        651,\n        534,\n        607,\n        812,\n        977,\n        214,\n        570,\n        941,\n        483,\n        164,\n        955,\n        780,\n        612,\n        741,\n        528,\n        604,\n        411,\n        920,\n        932,\n        806,\n        157,\n        683,\n        497,\n        489,\n        482,\n        458,\n        304,\n        591,\n        868,\n        789,\n        343,\n        151,\n        962,\n        631,\n        502,\n        569,\n        118,\n        161,\n        263,\n        299,\n        75,\n        586,\n        996,\n        921,\n        680,\n        95,\n        439,\n        833,\n        869,\n        951,\n        496,\n        357,\n        469,\n        617,\n        563,\n        632,\n        49,\n        83,\n        76,\n        454,\n        459,\n        205,\n        740,\n        673,\n        685,\n        351,\n        811,\n        473,\n        345,\n        296,\n        387,\n        963,\n        744,\n        212,\n        82,\n        877,\n        721,\n        711,\n        211,\n        560,\n        498,\n        240,\n        543,\n    
    85,\n        577,\n        358,\n        972,\n        376,\n        652,\n        403,\n        465,\n        255,\n        525,\n        247,\n        79,\n        184,\n        332,\n        705,\n        102,\n        109,\n        662,\n        523,\n        37,\n        903,\n        306,\n        883,\n        880,\n        117,\n        26,\n        808,\n        878,\n        24,\n        183,\n        384,\n        101,\n        94,\n        800,\n        864,\n        954,\n        852,\n        146,\n        656,\n        5,\n        822,\n        544,\n        326,\n        111,\n        801,\n        724,\n        597,\n        707,\n        124,\n        268,\n        383,\n        333,\n        28,\n        995,\n        436,\n        816,\n        435,\n        608,\n        276,\n        845,\n        940,\n        676,\n        18,\n        885,\n        428,\n        430,\n        765,\n        508,\n        381,\n        717,\n        818,\n        44,\n        804,\n        894,\n        829,\n        382,\n        836,\n        182,\n        580,\n        978,\n        389,\n        3,\n        568,\n        130,\n        126,\n        438,\n        336,\n        867,\n        271,\n        369,\n        311,\n        984,\n        398,\n        827,\n        912,\n        73,\n        722,\n        160,\n        323,\n        784,\n        514,\n        527,\n        88,\n        923,\n        494,\n        70,\n        897,\n        882,\n        47,\n        129,\n        56,\n        1,\n        330,\n        946,\n        52,\n        433,\n        828,\n        539,\n        751,\n        254,\n        708,\n        51,\n        409,\n        584,\n        499,\n        849,\n        131,\n        280,\n        245,\n        815,\n        677,\n        547,\n        125,\n        949,\n        226,\n        360,\n        781,\n        747,\n        976,\n        258,\n        128,\n        682,\n        116,\n        626,\n        
54,\n        905,\n        549,\n        537,\n        802,\n        750,\n        763,\n        515,\n        175,\n        726,\n        492,\n        986,\n        50,\n        934,\n        440,\n        203,\n        524,\n        35,\n        860,\n        250,\n        122,\n        293,\n        602,\n        69,\n        232,\n        349,\n        556,\n        295,\n        531,\n        678,\n        775,\n        33,\n        753,\n        823,\n        444,\n        522,\n        637,\n        488,\n        4,\n        204,\n        838,\n        423,\n        964,\n        419,\n        581,\n        429,\n        297,\n        426,\n        817,\n        354,\n        475,\n        728,\n        530,\n        841,\n        917,\n        132,\n        782,\n        353,\n        634,\n        87,\n        657,\n        348,\n        786,\n        187,\n        588,\n        803,\n        195,\n        542,\n        34,\n        123,\n        230,\n        879,\n        461,\n        961,\n        223,\n        936,\n        906,\n        81,\n        173,\n        448,\n        229,\n        691,\n        341,\n        329,\n        772,\n        104,\n        929,\n        714,\n        665,\n        445,\n        694,\n        191,\n        335,\n        244,\n        947,\n        669,\n        227,\n        512,\n        850,\n        134,\n        752,\n        700,\n        892,\n        27,\n        107,\n        831,\n        307,\n        270,\n        825,\n        778,\n        320,\n        189,\n        372,\n        181,\n        327,\n        615,\n        997,\n        305,\n        467,\n        643,\n        257,\n        994,\n        21,\n        692,\n        62,\n        799,\n        22,\n        501,\n        755,\n        854,\n        723,\n        623,\n        791,\n        695,\n        361,\n        477,\n        340,\n        642,\n        887,\n        61,\n        136,\n        647\n    ]\n}"
  },
  {
    "path": "datainfo/lava_lamp/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        408,\n        380,\n        248,\n        491,\n        35,\n        402,\n        511,\n        290,\n        310,\n        194,\n        519,\n        539,\n        512,\n        340,\n        123,\n        303,\n        190,\n        102,\n        426,\n        544,\n        296,\n        470,\n        224,\n        236,\n        353,\n        238,\n        561,\n        448,\n        227,\n        554,\n        432,\n        221,\n        390,\n        9,\n        26,\n        22,\n        31,\n        325,\n        104,\n        234,\n        272,\n        456,\n        2,\n        443,\n        399,\n        29,\n        499,\n        96,\n        214,\n        388,\n        483,\n        460,\n        507,\n        120,\n        261,\n        64,\n        137\n    ],\n    \"val\": [\n        463,\n        337,\n        565,\n        235,\n        180,\n        295,\n        433,\n        176,\n        263,\n        207,\n        118,\n        503,\n        440,\n        276,\n        526,\n        394,\n        6,\n        116,\n        257,\n        86,\n        87,\n        331,\n        487,\n        529,\n        524,\n        314,\n        434,\n        360,\n        157,\n        528,\n        240,\n        15,\n        375,\n        250,\n        189,\n        201,\n        430,\n        266,\n        83,\n        398,\n        55,\n        260,\n        339,\n        531,\n        44,\n        191,\n        377,\n        345,\n        397,\n        359,\n        451,\n        280,\n        187,\n        88,\n        522,\n        212,\n        206\n    ],\n    \"train\": [\n        126,\n        245,\n        516,\n        84,\n        527,\n        351,\n        124,\n        165,\n        462,\n        504,\n        393,\n        447,\n        352,\n        54,\n        385,\n        251,\n        155,\n        52,\n        327,\n        160,\n        244,\n        404,\n        498,\n        453,\n       
 392,\n        67,\n        396,\n        131,\n        76,\n        21,\n        475,\n        163,\n        365,\n        274,\n        154,\n        422,\n        110,\n        413,\n        128,\n        558,\n        217,\n        184,\n        241,\n        362,\n        17,\n        195,\n        464,\n        279,\n        326,\n        147,\n        89,\n        249,\n        496,\n        216,\n        59,\n        410,\n        342,\n        103,\n        106,\n        320,\n        73,\n        505,\n        378,\n        381,\n        457,\n        334,\n        142,\n        318,\n        72,\n        530,\n        341,\n        292,\n        482,\n        545,\n        304,\n        253,\n        537,\n        161,\n        38,\n        98,\n        239,\n        523,\n        79,\n        95,\n        316,\n        284,\n        145,\n        192,\n        285,\n        172,\n        135,\n        247,\n        409,\n        13,\n        324,\n        555,\n        478,\n        328,\n        209,\n        374,\n        133,\n        65,\n        119,\n        357,\n        113,\n        298,\n        162,\n        368,\n        233,\n        489,\n        60,\n        461,\n        45,\n        101,\n        306,\n        370,\n        1,\n        223,\n        277,\n        121,\n        226,\n        125,\n        497,\n        220,\n        371,\n        288,\n        51,\n        552,\n        5,\n        541,\n        134,\n        495,\n        185,\n        267,\n        62,\n        77,\n        550,\n        47,\n        490,\n        178,\n        536,\n        407,\n        335,\n        494,\n        210,\n        506,\n        481,\n        71,\n        348,\n        406,\n        63,\n        476,\n        78,\n        181,\n        367,\n        405,\n        317,\n        91,\n        256,\n        486,\n        333,\n        361,\n        458,\n        435,\n        469,\n        153,\n        308,\n        168,\n        418,\n     
   286,\n        183,\n        100,\n        198,\n        179,\n        343,\n        450,\n        449,\n        468,\n        508,\n        321,\n        502,\n        297,\n        25,\n        315,\n        122,\n        425,\n        329,\n        500,\n        376,\n        138,\n        415,\n        312,\n        301,\n        525,\n        40,\n        204,\n        560,\n        553,\n        354,\n        208,\n        41,\n        358,\n        532,\n        338,\n        61,\n        48,\n        421,\n        444,\n        222,\n        417,\n        23,\n        557,\n        171,\n        237,\n        3,\n        225,\n        428,\n        11,\n        243,\n        27,\n        70,\n        412,\n        534,\n        452,\n        229,\n        75,\n        93,\n        269,\n        271,\n        389,\n        57,\n        140,\n        146,\n        442,\n        363,\n        200,\n        484,\n        480,\n        538,\n        309,\n        32,\n        424,\n        37,\n        474,\n        300,\n        97,\n        350,\n        547,\n        355,\n        349,\n        543,\n        366,\n        141,\n        467,\n        34,\n        485,\n        445,\n        149,\n        369,\n        174,\n        431,\n        188,\n        382,\n        347,\n        170,\n        391,\n        509,\n        384,\n        510,\n        68,\n        43,\n        20,\n        33,\n        562,\n        549,\n        427,\n        273,\n        387,\n        518,\n        49,\n        136,\n        109,\n        219,\n        173,\n        69,\n        167,\n        305,\n        80,\n        533,\n        144,\n        205,\n        166,\n        439,\n        255,\n        429,\n        151,\n        199,\n        53,\n        252,\n        99,\n        356,\n        293,\n        437,\n        50,\n        423,\n        488,\n        193,\n        477,\n        493,\n        472,\n        455,\n        520,\n        111,\n        302,\n   
     19,\n        446,\n        4,\n        441,\n        479,\n        289,\n        213,\n        383,\n        330,\n        158,\n        39,\n        400,\n        156,\n        24,\n        108,\n        564,\n        152,\n        30,\n        346,\n        323,\n        372,\n        294,\n        202,\n        559,\n        332,\n        268,\n        114,\n        230,\n        264,\n        322,\n        112,\n        278,\n        436,\n        259,\n        228,\n        82,\n        18,\n        74,\n        203,\n        542,\n        115,\n        10,\n        540,\n        517,\n        107,\n        563,\n        129,\n        492,\n        132,\n        556,\n        215,\n        175,\n        197,\n        159,\n        12,\n        58,\n        242,\n        254,\n        164,\n        501,\n        232,\n        150,\n        364,\n        473,\n        139,\n        336,\n        471,\n        270,\n        403,\n        81,\n        85,\n        513,\n        148,\n        459,\n        94,\n        419,\n        56,\n        454,\n        127,\n        143,\n        395,\n        386,\n        7,\n        231,\n        8,\n        42,\n        36,\n        344,\n        16,\n        130,\n        282,\n        46,\n        92,\n        299,\n        281,\n        90,\n        546,\n        117,\n        411,\n        14,\n        307,\n        548,\n        169,\n        313,\n        514,\n        319,\n        465,\n        275,\n        0,\n        177,\n        535,\n        182,\n        416,\n        515,\n        211,\n        258,\n        466,\n        283,\n        291,\n        186,\n        246,\n        28,\n        218,\n        105,\n        287,\n        521,\n        265,\n        66,\n        414,\n        262,\n        379,\n        420,\n        438,\n        401,\n        196,\n        551,\n        373,\n        311\n    ]\n}"
  },
  {
    "path": "datainfo/lava_lamp/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        462,\n        460,\n        465,\n        523,\n        512,\n        510,\n        285,\n        501,\n        255,\n        472,\n        409,\n        165,\n        528,\n        370,\n        362,\n        546,\n        424,\n        456,\n        186,\n        526,\n        368,\n        531,\n        522,\n        139,\n        177,\n        332,\n        180,\n        24,\n        236,\n        241,\n        181,\n        168,\n        538,\n        433,\n        389,\n        326,\n        476,\n        372,\n        28,\n        557,\n        274,\n        514,\n        455,\n        380,\n        521,\n        402,\n        441,\n        162,\n        36,\n        217,\n        257,\n        315,\n        173,\n        369,\n        86,\n        93,\n        57\n    ],\n    \"val\": [\n        30,\n        54,\n        381,\n        97,\n        495,\n        4,\n        505,\n        174,\n        420,\n        401,\n        38,\n        451,\n        319,\n        350,\n        187,\n        262,\n        250,\n        106,\n        374,\n        159,\n        208,\n        301,\n        489,\n        333,\n        259,\n        265,\n        287,\n        258,\n        425,\n        361,\n        519,\n        155,\n        158,\n        245,\n        466,\n        395,\n        137,\n        560,\n        464,\n        448,\n        85,\n        427,\n        358,\n        417,\n        166,\n        481,\n        113,\n        337,\n        249,\n        233,\n        530,\n        515,\n        471,\n        371,\n        290,\n        179,\n        537\n    ],\n    \"train\": [\n        89,\n        278,\n        542,\n        391,\n        303,\n        504,\n        223,\n        215,\n        320,\n        114,\n        559,\n        327,\n        377,\n        331,\n        488,\n        336,\n        76,\n        555,\n        14,\n        206,\n        539,\n        81,\n        96,\n        507,\n  
      153,\n        124,\n        150,\n        225,\n        470,\n        399,\n        284,\n        50,\n        307,\n        5,\n        33,\n        203,\n        509,\n        199,\n        541,\n        445,\n        341,\n        457,\n        218,\n        239,\n        273,\n        83,\n        335,\n        246,\n        19,\n        556,\n        322,\n        393,\n        224,\n        475,\n        387,\n        340,\n        385,\n        463,\n        352,\n        169,\n        357,\n        133,\n        252,\n        483,\n        90,\n        384,\n        230,\n        170,\n        508,\n        305,\n        411,\n        207,\n        244,\n        347,\n        296,\n        3,\n        79,\n        143,\n        421,\n        440,\n        423,\n        379,\n        189,\n        275,\n        221,\n        256,\n        408,\n        175,\n        243,\n        517,\n        458,\n        355,\n        121,\n        163,\n        183,\n        61,\n        55,\n        563,\n        272,\n        1,\n        188,\n        397,\n        110,\n        62,\n        228,\n        392,\n        182,\n        529,\n        254,\n        449,\n        102,\n        151,\n        482,\n        104,\n        375,\n        478,\n        498,\n        365,\n        443,\n        346,\n        324,\n        72,\n        312,\n        27,\n        141,\n        279,\n        496,\n        234,\n        253,\n        154,\n        160,\n        413,\n        410,\n        551,\n        144,\n        291,\n        439,\n        487,\n        506,\n        364,\n        292,\n        516,\n        105,\n        64,\n        235,\n        70,\n        220,\n        13,\n        415,\n        240,\n        46,\n        545,\n        7,\n        400,\n        339,\n        264,\n        394,\n        193,\n        103,\n        286,\n        311,\n        518,\n        43,\n        271,\n        329,\n        63,\n        454,\n        238,\n        
195,\n        406,\n        138,\n        396,\n        547,\n        363,\n        219,\n        317,\n        297,\n        84,\n        407,\n        171,\n        398,\n        99,\n        280,\n        32,\n        281,\n        520,\n        212,\n        405,\n        511,\n        544,\n        293,\n        491,\n        45,\n        226,\n        323,\n        536,\n        147,\n        328,\n        149,\n        21,\n        202,\n        561,\n        178,\n        109,\n        140,\n        66,\n        152,\n        213,\n        40,\n        479,\n        342,\n        75,\n        513,\n        438,\n        300,\n        383,\n        527,\n        330,\n        122,\n        360,\n        316,\n        68,\n        95,\n        204,\n        548,\n        117,\n        91,\n        74,\n        426,\n        535,\n        429,\n        11,\n        198,\n        120,\n        540,\n        304,\n        6,\n        295,\n        533,\n        444,\n        100,\n        164,\n        492,\n        298,\n        432,\n        419,\n        247,\n        366,\n        502,\n        31,\n        412,\n        276,\n        493,\n        192,\n        35,\n        270,\n        200,\n        416,\n        59,\n        98,\n        251,\n        112,\n        39,\n        277,\n        283,\n        308,\n        145,\n        480,\n        499,\n        469,\n        306,\n        248,\n        210,\n        435,\n        231,\n        8,\n        101,\n        156,\n        359,\n        314,\n        211,\n        288,\n        390,\n        190,\n        148,\n        289,\n        60,\n        562,\n        356,\n        486,\n        485,\n        48,\n        92,\n        268,\n        494,\n        26,\n        313,\n        261,\n        558,\n        222,\n        436,\n        108,\n        194,\n        549,\n        484,\n        345,\n        237,\n        266,\n        524,\n        111,\n        53,\n        353,\n        564,\n     
   553,\n        51,\n        418,\n        123,\n        44,\n        467,\n        56,\n        348,\n        209,\n        196,\n        554,\n        403,\n        461,\n        269,\n        142,\n        434,\n        131,\n        428,\n        534,\n        490,\n        446,\n        447,\n        41,\n        128,\n        37,\n        227,\n        119,\n        404,\n        431,\n        260,\n        118,\n        325,\n        232,\n        49,\n        87,\n        82,\n        67,\n        17,\n        129,\n        430,\n        343,\n        71,\n        503,\n        9,\n        450,\n        214,\n        310,\n        134,\n        132,\n        459,\n        73,\n        167,\n        263,\n        500,\n        201,\n        299,\n        477,\n        414,\n        543,\n        550,\n        52,\n        161,\n        351,\n        338,\n        47,\n        115,\n        242,\n        78,\n        497,\n        318,\n        205,\n        135,\n        23,\n        378,\n        309,\n        282,\n        229,\n        157,\n        15,\n        468,\n        172,\n        146,\n        565,\n        382,\n        552,\n        321,\n        474,\n        176,\n        2,\n        18,\n        77,\n        126,\n        22,\n        473,\n        197,\n        0,\n        354,\n        442,\n        94,\n        376,\n        80,\n        65,\n        130,\n        191,\n        10,\n        373,\n        20,\n        525,\n        34,\n        58,\n        42,\n        12,\n        344,\n        127,\n        88,\n        184,\n        185,\n        29,\n        16,\n        388,\n        367,\n        216,\n        452,\n        107,\n        422,\n        125,\n        136,\n        437,\n        69,\n        267,\n        386,\n        453,\n        349,\n        116,\n        302,\n        532,\n        25,\n        334,\n        294\n    ]\n}"
  },
  {
    "path": "datainfo/lava_lamp/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        425,\n        324,\n        216,\n        106,\n        334,\n        167,\n        286,\n        29,\n        344,\n        549,\n        417,\n        359,\n        395,\n        519,\n        431,\n        541,\n        446,\n        264,\n        222,\n        506,\n        139,\n        36,\n        99,\n        374,\n        137,\n        455,\n        404,\n        437,\n        396,\n        484,\n        275,\n        31,\n        308,\n        43,\n        163,\n        65,\n        15,\n        399,\n        535,\n        155,\n        237,\n        154,\n        406,\n        487,\n        553,\n        481,\n        196,\n        239,\n        265,\n        480,\n        13,\n        67,\n        485,\n        378,\n        133,\n        557,\n        243\n    ],\n    \"val\": [\n        173,\n        444,\n        21,\n        353,\n        79,\n        134,\n        312,\n        149,\n        208,\n        101,\n        16,\n        274,\n        307,\n        55,\n        39,\n        3,\n        158,\n        18,\n        120,\n        258,\n        142,\n        472,\n        451,\n        282,\n        169,\n        300,\n        367,\n        193,\n        23,\n        389,\n        314,\n        309,\n        22,\n        60,\n        525,\n        212,\n        393,\n        218,\n        150,\n        10,\n        77,\n        459,\n        210,\n        34,\n        409,\n        176,\n        45,\n        247,\n        327,\n        536,\n        246,\n        32,\n        63,\n        145,\n        136,\n        293,\n        505\n    ],\n    \"train\": [\n        188,\n        28,\n        326,\n        171,\n        483,\n        422,\n        494,\n        319,\n        440,\n        305,\n        59,\n        355,\n        379,\n        299,\n        73,\n        383,\n        178,\n        54,\n        376,\n        331,\n        352,\n        454,\n        234,\n        421,\n        
112,\n        42,\n        442,\n        529,\n        354,\n        445,\n        342,\n        364,\n        46,\n        187,\n        565,\n        126,\n        103,\n        436,\n        181,\n        104,\n        199,\n        35,\n        41,\n        450,\n        287,\n        191,\n        38,\n        449,\n        427,\n        539,\n        156,\n        526,\n        358,\n        270,\n        382,\n        517,\n        313,\n        388,\n        325,\n        85,\n        24,\n        555,\n        458,\n        242,\n        402,\n        466,\n        426,\n        272,\n        423,\n        37,\n        411,\n        443,\n        227,\n        292,\n        100,\n        175,\n        273,\n        298,\n        322,\n        470,\n        75,\n        56,\n        335,\n        91,\n        504,\n        551,\n        424,\n        384,\n        51,\n        119,\n        19,\n        231,\n        141,\n        500,\n        372,\n        74,\n        229,\n        88,\n        512,\n        205,\n        168,\n        520,\n        283,\n        279,\n        511,\n        202,\n        375,\n        71,\n        44,\n        269,\n        195,\n        47,\n        419,\n        50,\n        297,\n        351,\n        78,\n        276,\n        263,\n        248,\n        124,\n        410,\n        252,\n        105,\n        302,\n        157,\n        394,\n        476,\n        545,\n        435,\n        380,\n        509,\n        130,\n        9,\n        232,\n        76,\n        72,\n        236,\n        516,\n        362,\n        385,\n        387,\n        288,\n        206,\n        337,\n        471,\n        254,\n        336,\n        253,\n        194,\n        221,\n        561,\n        530,\n        260,\n        1,\n        330,\n        26,\n        528,\n        144,\n        81,\n        277,\n        96,\n        562,\n        524,\n        118,\n        320,\n        217,\n        143,\n        190,\n       
 83,\n        7,\n        33,\n        318,\n        57,\n        251,\n        225,\n        151,\n        69,\n        255,\n        465,\n        267,\n        4,\n        200,\n        460,\n        390,\n        107,\n        226,\n        179,\n        369,\n        381,\n        201,\n        398,\n        110,\n        159,\n        240,\n        162,\n        207,\n        284,\n        166,\n        428,\n        117,\n        363,\n        281,\n        498,\n        508,\n        20,\n        203,\n        338,\n        185,\n        183,\n        245,\n        478,\n        241,\n        556,\n        109,\n        84,\n        48,\n        371,\n        92,\n        492,\n        490,\n        531,\n        165,\n        98,\n        89,\n        414,\n        228,\n        349,\n        80,\n        295,\n        391,\n        131,\n        538,\n        182,\n        489,\n        8,\n        123,\n        507,\n        12,\n        339,\n        397,\n        503,\n        499,\n        14,\n        219,\n        0,\n        316,\n        198,\n        82,\n        429,\n        532,\n        108,\n        563,\n        457,\n        370,\n        430,\n        127,\n        412,\n        289,\n        204,\n        249,\n        467,\n        544,\n        140,\n        475,\n        285,\n        356,\n        62,\n        534,\n        306,\n        420,\n        521,\n        310,\n        129,\n        64,\n        477,\n        58,\n        27,\n        482,\n        463,\n        268,\n        488,\n        365,\n        546,\n        257,\n        87,\n        418,\n        502,\n        6,\n        25,\n        433,\n        461,\n        262,\n        340,\n        558,\n        125,\n        341,\n        146,\n        495,\n        116,\n        333,\n        278,\n        147,\n        496,\n        554,\n        462,\n        523,\n        439,\n        401,\n        261,\n        244,\n        2,\n        102,\n        456,\n        
211,\n        469,\n        209,\n        290,\n        214,\n        148,\n        213,\n        177,\n        518,\n        343,\n        564,\n        493,\n        215,\n        66,\n        537,\n        497,\n        501,\n        542,\n        328,\n        174,\n        400,\n        93,\n        294,\n        522,\n        97,\n        271,\n        17,\n        61,\n        115,\n        434,\n        230,\n        373,\n        111,\n        360,\n        172,\n        40,\n        86,\n        224,\n        114,\n        345,\n        407,\n        164,\n        386,\n        438,\n        49,\n        357,\n        332,\n        527,\n        547,\n        95,\n        514,\n        122,\n        533,\n        513,\n        113,\n        256,\n        468,\n        560,\n        350,\n        291,\n        559,\n        53,\n        447,\n        153,\n        135,\n        448,\n        392,\n        474,\n        94,\n        186,\n        90,\n        543,\n        464,\n        303,\n        152,\n        233,\n        408,\n        128,\n        189,\n        416,\n        346,\n        540,\n        413,\n        11,\n        250,\n        377,\n        473,\n        361,\n        311,\n        405,\n        347,\n        180,\n        238,\n        170,\n        321,\n        432,\n        30,\n        68,\n        323,\n        301,\n        315,\n        486,\n        491,\n        161,\n        296,\n        552,\n        403,\n        5,\n        452,\n        280,\n        548,\n        453,\n        132,\n        223,\n        550,\n        121,\n        366,\n        368,\n        510,\n        220,\n        138,\n        259,\n        415,\n        317,\n        52,\n        515,\n        348,\n        304,\n        329,\n        197,\n        266,\n        235,\n        192,\n        479,\n        441,\n        70,\n        184,\n        160\n    ]\n}"
  },
  {
    "path": "datainfo/reaction_diffusion/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        83,\n        60,\n        57,\n        63,\n        15,\n        32,\n        8,\n        97,\n        72,\n        17\n    ],\n    \"val\": [\n        92,\n        0,\n        77,\n        55,\n        49,\n        3,\n        62,\n        12,\n        26,\n        48\n    ],\n    \"train\": [\n        53,\n        37,\n        65,\n        51,\n        4,\n        20,\n        38,\n        9,\n        10,\n        81,\n        44,\n        36,\n        84,\n        50,\n        96,\n        90,\n        66,\n        16,\n        80,\n        33,\n        24,\n        52,\n        91,\n        99,\n        64,\n        5,\n        58,\n        76,\n        39,\n        79,\n        23,\n        94,\n        30,\n        73,\n        25,\n        47,\n        31,\n        45,\n        19,\n        87,\n        42,\n        68,\n        95,\n        21,\n        7,\n        67,\n        46,\n        82,\n        11,\n        6,\n        41,\n        86,\n        88,\n        70,\n        18,\n        78,\n        71,\n        59,\n        43,\n        61,\n        22,\n        14,\n        35,\n        93,\n        56,\n        28,\n        98,\n        54,\n        27,\n        89,\n        1,\n        69,\n        74,\n        2,\n        85,\n        40,\n        13,\n        75,\n        29,\n        34\n    ]\n}"
  },
  {
    "path": "datainfo/reaction_diffusion/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        77,\n        32,\n        39,\n        85,\n        94,\n        21,\n        46,\n        10,\n        11,\n        7\n    ],\n    \"val\": [\n        47,\n        65,\n        50,\n        81,\n        55,\n        20,\n        74,\n        4,\n        90,\n        27\n    ],\n    \"train\": [\n        0,\n        76,\n        61,\n        63,\n        1,\n        71,\n        2,\n        6,\n        16,\n        19,\n        13,\n        24,\n        49,\n        12,\n        75,\n        9,\n        83,\n        72,\n        5,\n        41,\n        99,\n        45,\n        89,\n        53,\n        79,\n        18,\n        52,\n        92,\n        14,\n        42,\n        68,\n        44,\n        38,\n        84,\n        36,\n        17,\n        31,\n        15,\n        70,\n        88,\n        25,\n        97,\n        51,\n        73,\n        66,\n        37,\n        78,\n        33,\n        80,\n        26,\n        82,\n        28,\n        60,\n        35,\n        43,\n        57,\n        23,\n        58,\n        91,\n        8,\n        62,\n        93,\n        98,\n        86,\n        29,\n        30,\n        22,\n        95,\n        67,\n        54,\n        48,\n        40,\n        59,\n        96,\n        3,\n        87,\n        34,\n        64,\n        56,\n        69\n    ]\n}"
  },
  {
    "path": "datainfo/reaction_diffusion/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        8,\n        74,\n        80,\n        60,\n        77,\n        47,\n        16,\n        69,\n        75,\n        30\n    ],\n    \"val\": [\n        85,\n        97,\n        87,\n        24,\n        29,\n        70,\n        33,\n        93,\n        1,\n        94\n    ],\n    \"train\": [\n        35,\n        41,\n        45,\n        4,\n        76,\n        12,\n        0,\n        64,\n        42,\n        11,\n        81,\n        39,\n        7,\n        48,\n        78,\n        15,\n        53,\n        62,\n        9,\n        52,\n        68,\n        55,\n        63,\n        65,\n        61,\n        67,\n        95,\n        18,\n        58,\n        86,\n        99,\n        92,\n        10,\n        17,\n        72,\n        21,\n        14,\n        44,\n        37,\n        91,\n        22,\n        59,\n        83,\n        32,\n        26,\n        98,\n        40,\n        27,\n        43,\n        96,\n        13,\n        31,\n        57,\n        2,\n        6,\n        23,\n        56,\n        71,\n        28,\n        36,\n        51,\n        46,\n        25,\n        54,\n        73,\n        79,\n        34,\n        3,\n        38,\n        5,\n        20,\n        90,\n        88,\n        49,\n        66,\n        89,\n        84,\n        19,\n        50,\n        82\n    ]\n}"
  },
  {
    "path": "datainfo/single_pendulum/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        187,\n        370,\n        362,\n        470,\n        57,\n        938,\n        678,\n        3,\n        708,\n        1137,\n        730,\n        993,\n        846,\n        1033,\n        409,\n        746,\n        985,\n        114,\n        872,\n        420,\n        1062,\n        264,\n        1049,\n        785,\n        11,\n        551,\n        940,\n        723,\n        704,\n        1052,\n        828,\n        475,\n        1105,\n        408,\n        25,\n        464,\n        1028,\n        345,\n        348,\n        806,\n        631,\n        89,\n        961,\n        60,\n        1002,\n        758,\n        1142,\n        1066,\n        335,\n        221,\n        1041,\n        898,\n        177,\n        767,\n        1123,\n        751,\n        354,\n        848,\n        827,\n        497,\n        983,\n        70,\n        805,\n        1034,\n        1022,\n        581,\n        621,\n        388,\n        1039,\n        1171,\n        1025,\n        681,\n        247,\n        607,\n        380,\n        204,\n        1139,\n        852,\n        44,\n        593,\n        941,\n        448,\n        472,\n        707,\n        477,\n        1132,\n        1015,\n        896,\n        454,\n        1080,\n        59,\n        864,\n        443,\n        780,\n        18,\n        1108,\n        52,\n        45,\n        62,\n        650,\n        209,\n        468,\n        545,\n        912,\n        4,\n        886,\n        798,\n        58,\n        999,\n        192,\n        429,\n        777,\n        967,\n        920,\n        1014,\n        241,\n        522,\n        129,\n        1165,\n        275\n    ],\n    \"val\": [\n        411,\n        892,\n        626,\n        333,\n        17,\n        1071,\n        516,\n        303,\n        399,\n        1151,\n        960,\n        106,\n        504,\n        198,\n        605,\n        1172,\n        918,\n        
690,\n        587,\n        210,\n        101,\n        355,\n        205,\n        387,\n        1000,\n        521,\n        637,\n        720,\n        1186,\n        890,\n        888,\n        847,\n        175,\n        471,\n        583,\n        922,\n        1096,\n        1046,\n        839,\n        604,\n        38,\n        870,\n        899,\n        574,\n        8,\n        133,\n        258,\n        578,\n        426,\n        162,\n        761,\n        305,\n        1122,\n        939,\n        317,\n        78,\n        879,\n        1036,\n        313,\n        1053,\n        1167,\n        217,\n        991,\n        257,\n        611,\n        120,\n        1093,\n        657,\n        808,\n        1178,\n        457,\n        923,\n        451,\n        873,\n        1183,\n        328,\n        72,\n        299,\n        813,\n        36,\n        461,\n        42,\n        884,\n        428,\n        1044,\n        519,\n        222,\n        529,\n        385,\n        862,\n        703,\n        791,\n        638,\n        48,\n        233,\n        970,\n        1016,\n        659,\n        931,\n        603,\n        558,\n        344,\n        1079,\n        326,\n        342,\n        142,\n        594,\n        705,\n        378,\n        224,\n        550,\n        511,\n        575,\n        29,\n        927,\n        34,\n        170,\n        144,\n        66,\n        1196\n    ],\n    \"train\": [\n        483,\n        489,\n        341,\n        685,\n        96,\n        1,\n        1148,\n        542,\n        656,\n        744,\n        1085,\n        613,\n        855,\n        775,\n        680,\n        111,\n        1081,\n        648,\n        308,\n        733,\n        1023,\n        189,\n        759,\n        61,\n        482,\n        639,\n        1190,\n        228,\n        776,\n        596,\n        885,\n        1024,\n        1187,\n        1048,\n        295,\n        1158,\n        1146,\n        
1094,\n        819,\n        684,\n        540,\n        793,\n        856,\n        1099,\n        622,\n        262,\n        773,\n        609,\n        107,\n        263,\n        260,\n        1160,\n        958,\n        474,\n        672,\n        739,\n        442,\n        478,\n        887,\n        1083,\n        734,\n        612,\n        1084,\n        956,\n        843,\n        55,\n        1067,\n        750,\n        823,\n        9,\n        532,\n        112,\n        282,\n        329,\n        520,\n        874,\n        359,\n        1125,\n        1102,\n        514,\n        668,\n        1164,\n        459,\n        197,\n        844,\n        753,\n        635,\n        633,\n        185,\n        1195,\n        966,\n        201,\n        906,\n        1059,\n        569,\n        146,\n        580,\n        39,\n        1047,\n        446,\n        597,\n        276,\n        43,\n        1042,\n        533,\n        561,\n        166,\n        1003,\n        154,\n        778,\n        647,\n        243,\n        67,\n        768,\n        617,\n        1149,\n        634,\n        981,\n        649,\n        450,\n        507,\n        215,\n        1057,\n        740,\n        691,\n        710,\n        1120,\n        858,\n        889,\n        682,\n        71,\n        531,\n        869,\n        415,\n        849,\n        444,\n        237,\n        716,\n        214,\n        764,\n        735,\n        765,\n        595,\n        486,\n        90,\n        834,\n        989,\n        437,\n        292,\n        1127,\n        549,\n        330,\n        16,\n        304,\n        559,\n        1168,\n        104,\n        186,\n        916,\n        719,\n        99,\n        623,\n        476,\n        293,\n        498,\n        286,\n        1136,\n        905,\n        503,\n        143,\n        952,\n        917,\n        167,\n        242,\n        26,\n        797,\n        149,\n        315,\n        652,\n        
897,\n        871,\n        234,\n        271,\n        223,\n        1029,\n        769,\n        238,\n        673,\n        924,\n        294,\n        1199,\n        567,\n        33,\n        207,\n        821,\n        945,\n        795,\n        895,\n        548,\n        784,\n        729,\n        987,\n        412,\n        908,\n        986,\n        965,\n        375,\n        427,\n        6,\n        891,\n        456,\n        195,\n        980,\n        641,\n        435,\n        419,\n        1004,\n        925,\n        721,\n        538,\n        80,\n        718,\n        762,\n        679,\n        1063,\n        131,\n        937,\n        311,\n        307,\n        552,\n        485,\n        1054,\n        1043,\n        959,\n        1176,\n        840,\n        1037,\n        432,\n        699,\n        1177,\n        812,\n        424,\n        1157,\n        752,\n        360,\n        190,\n        68,\n        316,\n        693,\n        1194,\n        383,\n        384,\n        951,\n        517,\n        741,\n        1110,\n        1072,\n        671,\n        676,\n        219,\n        179,\n        1121,\n        30,\n        396,\n        1106,\n        1005,\n        585,\n        402,\n        321,\n        954,\n        943,\n        606,\n        749,\n        255,\n        859,\n        530,\n        865,\n        390,\n        395,\n        976,\n        877,\n        556,\n        268,\n        1007,\n        755,\n        957,\n        909,\n        702,\n        933,\n        509,\n        929,\n        130,\n        582,\n        289,\n        1073,\n        277,\n        756,\n        973,\n        1100,\n        394,\n        525,\n        81,\n        598,\n        100,\n        147,\n        379,\n        196,\n        644,\n        851,\n        2,\n        782,\n        351,\n        824,\n        1090,\n        56,\n        995,\n        586,\n        636,\n        919,\n        660,\n        108,\n        
199,\n        1045,\n        280,\n        1091,\n        1107,\n        553,\n        757,\n        421,\n        193,\n        655,\n        653,\n        119,\n        494,\n        492,\n        1019,\n        675,\n        343,\n        1134,\n        140,\n        339,\n        692,\n        453,\n        1018,\n        438,\n        1163,\n        964,\n        571,\n        695,\n        640,\n        28,\n        696,\n        95,\n        51,\n        510,\n        164,\n        493,\n        910,\n        54,\n        1188,\n        156,\n        911,\n        998,\n        1170,\n        123,\n        838,\n        158,\n        184,\n        760,\n        401,\n        371,\n        1141,\n        974,\n        269,\n        235,\n        134,\n        770,\n        75,\n        372,\n        747,\n        227,\n        296,\n        1032,\n        374,\n        1173,\n        487,\n        97,\n        434,\n        127,\n        893,\n        1129,\n        841,\n        206,\n        591,\n        15,\n        338,\n        125,\n        361,\n        619,\n        694,\n        113,\n        366,\n        599,\n        413,\n        469,\n        404,\n        842,\n        88,\n        1155,\n        787,\n        200,\n        1092,\n        1068,\n        738,\n        714,\n        463,\n        630,\n        543,\n        913,\n        334,\n        152,\n        481,\n        1114,\n        645,\n        285,\n        462,\n        854,\n        188,\n        337,\n        267,\n        425,\n        357,\n        984,\n        717,\n        137,\n        1031,\n        121,\n        947,\n        270,\n        1131,\n        473,\n        484,\n        1086,\n        1013,\n        836,\n        570,\n        47,\n        353,\n        82,\n        1088,\n        832,\n        84,\n        501,\n        226,\n        336,\n        1075,\n        666,\n        35,\n        216,\n        781,\n        508,\n        50,\n        465,\n        
491,\n        23,\n        763,\n        669,\n        926,\n        674,\n        132,\n        687,\n        236,\n        77,\n        155,\n        278,\n        350,\n        1058,\n        715,\n        1082,\n        534,\n        135,\n        467,\n        178,\n        5,\n        309,\n        397,\n        417,\n        37,\n        754,\n        670,\n        433,\n        902,\n        815,\n        524,\n        479,\n        544,\n        664,\n        79,\n        701,\n        789,\n        528,\n        748,\n        829,\n        148,\n        332,\n        495,\n        904,\n        539,\n        381,\n        745,\n        790,\n        69,\n        1061,\n        414,\n        624,\n        398,\n        727,\n        368,\n        202,\n        537,\n        174,\n        535,\n        590,\n        1162,\n        1051,\n        643,\n        978,\n        1087,\n        642,\n        602,\n        915,\n        115,\n        972,\n        779,\n        1038,\n        731,\n        358,\n        816,\n        452,\n        850,\n        0,\n        331,\n        1145,\n        833,\n        515,\n        447,\n        568,\n        382,\n        663,\n        1135,\n        139,\n        1192,\n        13,\n        589,\n        689,\n        867,\n        290,\n        1180,\n        392,\n        213,\n        124,\n        430,\n        502,\n        365,\n        722,\n        172,\n        935,\n        651,\n        318,\n        279,\n        955,\n        950,\n        151,\n        291,\n        736,\n        724,\n        266,\n        248,\n        31,\n        21,\n        883,\n        665,\n        194,\n        1184,\n        572,\n        837,\n        254,\n        284,\n        820,\n        1150,\n        418,\n        546,\n        441,\n        122,\n        1021,\n        499,\n        725,\n        662,\n        963,\n        809,\n        406,\n        712,\n        391,\n        246,\n        455,\n        1124,\n   
     1065,\n        85,\n        875,\n        928,\n        914,\n        783,\n        536,\n        1074,\n        766,\n        168,\n        527,\n        393,\n        356,\n        403,\n        1027,\n        53,\n        994,\n        997,\n        436,\n        700,\n        159,\n        180,\n        386,\n        860,\n        831,\n        620,\n        126,\n        87,\n        608,\n        1111,\n        169,\n        1154,\n        565,\n        713,\n        1011,\n        319,\n        208,\n        646,\n        163,\n        377,\n        523,\n        73,\n        709,\n        1069,\n        281,\n        625,\n        573,\n        1117,\n        422,\n        230,\n        490,\n        506,\n        814,\n        244,\n        231,\n        654,\n        1161,\n        512,\n        772,\n        541,\n        181,\n        7,\n        944,\n        239,\n        932,\n        627,\n        176,\n        410,\n        1098,\n        407,\n        868,\n        802,\n        737,\n        907,\n        1153,\n        141,\n        901,\n        363,\n        63,\n        19,\n        880,\n        1008,\n        661,\n        24,\n        1156,\n        1101,\n        992,\n        1143,\n        405,\n        1104,\n        1115,\n        711,\n        449,\n        794,\n        1119,\n        562,\n        440,\n        1030,\n        698,\n        811,\n        1006,\n        253,\n        683,\n        49,\n        161,\n        1070,\n        975,\n        306,\n        182,\n        979,\n        706,\n        566,\n        688,\n        1109,\n        323,\n        32,\n        1060,\n        145,\n        211,\n        1197,\n        300,\n        616,\n        526,\n        726,\n        109,\n        312,\n        817,\n        1077,\n        153,\n        183,\n        1198,\n        969,\n        996,\n        881,\n        774,\n        513,\n        1133,\n        157,\n        799,\n        505,\n        367,\n        
297,\n        822,\n        1179,\n        22,\n        76,\n        1095,\n        1017,\n        900,\n        1193,\n        894,\n        1055,\n        249,\n        20,\n        225,\n        250,\n        94,\n        818,\n        1103,\n        962,\n        557,\n        103,\n        1064,\n        251,\n        310,\n        592,\n        810,\n        191,\n        953,\n        1130,\n        792,\n        968,\n        232,\n        948,\n        658,\n        588,\n        826,\n        92,\n        458,\n        771,\n        91,\n        287,\n        876,\n        369,\n        252,\n        203,\n        314,\n        1138,\n        554,\n        1169,\n        265,\n        364,\n        677,\n        480,\n        1175,\n        835,\n        796,\n        632,\n        803,\n        220,\n        256,\n        1097,\n        466,\n        615,\n        324,\n        65,\n        64,\n        1113,\n        320,\n        400,\n        460,\n        327,\n        610,\n        743,\n        863,\n        866,\n        10,\n        27,\n        853,\n        325,\n        667,\n        212,\n        102,\n        322,\n        488,\n        728,\n        259,\n        1056,\n        301,\n        555,\n        825,\n        882,\n        445,\n        105,\n        1010,\n        1009,\n        1152,\n        697,\n        171,\n        1040,\n        118,\n        165,\n        431,\n        600,\n        804,\n        245,\n        1189,\n        1020,\n        1116,\n        845,\n        1001,\n        423,\n        93,\n        14,\n        686,\n        628,\n        12,\n        1026,\n        1144,\n        46,\n        1126,\n        110,\n        283,\n        1181,\n        1012,\n        934,\n        577,\n        302,\n        373,\n        273,\n        83,\n        579,\n        229,\n        563,\n        584,\n        1166,\n        1140,\n        800,\n        601,\n        629,\n        117,\n        349,\n        128,\n      
  1089,\n        150,\n        807,\n        1185,\n        389,\n        74,\n        416,\n        40,\n        1035,\n        971,\n        788,\n        564,\n        1159,\n        949,\n        500,\n        732,\n        988,\n        618,\n        990,\n        930,\n        298,\n        116,\n        1118,\n        346,\n        376,\n        261,\n        861,\n        518,\n        614,\n        340,\n        1191,\n        274,\n        946,\n        1112,\n        1076,\n        173,\n        136,\n        86,\n        41,\n        742,\n        1078,\n        240,\n        1182,\n        786,\n        496,\n        547,\n        1050,\n        942,\n        903,\n        936,\n        352,\n        560,\n        1147,\n        857,\n        98,\n        977,\n        272,\n        218,\n        439,\n        347,\n        138,\n        801,\n        576,\n        830,\n        1128,\n        878,\n        982,\n        160,\n        1174,\n        288,\n        921\n    ]\n}"
  },
  {
    "path": "datainfo/single_pendulum/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        789,\n        3,\n        1071,\n        376,\n        321,\n        261,\n        523,\n        764,\n        43,\n        83,\n        51,\n        138,\n        235,\n        169,\n        1167,\n        510,\n        352,\n        737,\n        742,\n        116,\n        65,\n        866,\n        123,\n        431,\n        501,\n        544,\n        1163,\n        1069,\n        1113,\n        464,\n        559,\n        100,\n        120,\n        217,\n        391,\n        17,\n        699,\n        154,\n        750,\n        1048,\n        1001,\n        425,\n        638,\n        832,\n        1039,\n        1060,\n        1032,\n        621,\n        633,\n        982,\n        1181,\n        340,\n        664,\n        454,\n        996,\n        935,\n        718,\n        1171,\n        931,\n        1152,\n        1055,\n        1025,\n        1020,\n        571,\n        1003,\n        511,\n        1086,\n        944,\n        818,\n        330,\n        1156,\n        741,\n        724,\n        1178,\n        1075,\n        849,\n        912,\n        372,\n        1146,\n        1052,\n        736,\n        1162,\n        1044,\n        279,\n        355,\n        665,\n        361,\n        48,\n        472,\n        483,\n        363,\n        1147,\n        336,\n        1076,\n        867,\n        778,\n        652,\n        952,\n        745,\n        56,\n        1191,\n        549,\n        1028,\n        911,\n        1114,\n        761,\n        1042,\n        805,\n        882,\n        324,\n        1190,\n        73,\n        434,\n        515,\n        631,\n        346,\n        739,\n        173,\n        187,\n        115\n    ],\n    \"val\": [\n        954,\n        706,\n        296,\n        375,\n        625,\n        121,\n        943,\n        531,\n        19,\n        374,\n        491,\n        821,\n        679,\n        96,\n        942,\n        185,\n        983,\n   
     537,\n        951,\n        428,\n        902,\n        52,\n        605,\n        595,\n        21,\n        1158,\n        435,\n        444,\n        825,\n        746,\n        215,\n        986,\n        1175,\n        1037,\n        701,\n        1108,\n        389,\n        657,\n        548,\n        317,\n        1109,\n        475,\n        685,\n        533,\n        25,\n        222,\n        107,\n        237,\n        768,\n        186,\n        20,\n        102,\n        104,\n        246,\n        89,\n        1151,\n        524,\n        113,\n        1029,\n        418,\n        393,\n        1046,\n        1117,\n        9,\n        570,\n        1101,\n        525,\n        1160,\n        467,\n        164,\n        513,\n        150,\n        910,\n        476,\n        505,\n        64,\n        474,\n        928,\n        196,\n        349,\n        1149,\n        269,\n        68,\n        518,\n        1099,\n        287,\n        36,\n        859,\n        536,\n        530,\n        698,\n        294,\n        671,\n        997,\n        804,\n        1085,\n        917,\n        49,\n        208,\n        647,\n        191,\n        461,\n        968,\n        314,\n        822,\n        540,\n        93,\n        918,\n        1194,\n        63,\n        1000,\n        690,\n        585,\n        231,\n        704,\n        8,\n        74,\n        310,\n        507,\n        88\n    ],\n    \"train\": [\n        689,\n        907,\n        403,\n        285,\n        1073,\n        796,\n        316,\n        140,\n        1183,\n        132,\n        861,\n        1047,\n        167,\n        556,\n        985,\n        999,\n        1185,\n        891,\n        984,\n        199,\n        717,\n        40,\n        937,\n        354,\n        405,\n        271,\n        1155,\n        1125,\n        262,\n        900,\n        292,\n        304,\n        151,\n        909,\n        397,\n        580,\n        820,\n        
646,\n        667,\n        1057,\n        183,\n        684,\n        760,\n        450,\n        1159,\n        726,\n        198,\n        747,\n        880,\n        1012,\n        335,\n        719,\n        39,\n        913,\n        84,\n        168,\n        207,\n        1051,\n        396,\n        618,\n        1088,\n        834,\n        567,\n        146,\n        923,\n        814,\n        921,\n        629,\n        299,\n        1139,\n        45,\n        734,\n        216,\n        758,\n        302,\n        424,\n        128,\n        60,\n        406,\n        85,\n        447,\n        205,\n        710,\n        547,\n        566,\n        32,\n        156,\n        18,\n        892,\n        105,\n        1170,\n        239,\n        659,\n        1061,\n        797,\n        76,\n        248,\n        541,\n        641,\n        783,\n        969,\n        1199,\n        508,\n        94,\n        58,\n        394,\n        639,\n        863,\n        479,\n        819,\n        1034,\n        487,\n        947,\n        862,\n        283,\n        305,\n        327,\n        438,\n        744,\n        975,\n        386,\n        723,\n        202,\n        756,\n        979,\n        236,\n        1092,\n        1094,\n        774,\n        343,\n        219,\n        920,\n        360,\n        945,\n        601,\n        810,\n        998,\n        568,\n        932,\n        197,\n        603,\n        830,\n        527,\n        808,\n        940,\n        46,\n        792,\n        875,\n        816,\n        1005,\n        506,\n        916,\n        634,\n        976,\n        244,\n        1009,\n        2,\n        1198,\n        290,\n        766,\n        1182,\n        840,\n        417,\n        182,\n        249,\n        210,\n        443,\n        623,\n        307,\n        590,\n        1043,\n        592,\n        925,\n        155,\n        858,\n        675,\n        90,\n        427,\n        442,\n        848,\n   
     846,\n        1093,\n        692,\n        1030,\n        1131,\n        35,\n        255,\n        111,\n        1166,\n        469,\n        1016,\n        22,\n        614,\n        377,\n        1018,\n        26,\n        195,\n        499,\n        165,\n        37,\n        333,\n        883,\n        899,\n        247,\n        906,\n        981,\n        288,\n        365,\n        616,\n        693,\n        1070,\n        7,\n        1135,\n        441,\n        1179,\n        874,\n        188,\n        992,\n        851,\n        347,\n        1054,\n        456,\n        258,\n        991,\n        529,\n        1110,\n        1066,\n        970,\n        69,\n        370,\n        457,\n        735,\n        1137,\n        546,\n        1050,\n        13,\n        879,\n        312,\n        348,\n        872,\n        23,\n        226,\n        611,\n        1130,\n        678,\n        445,\n        57,\n        557,\n        569,\n        853,\n        668,\n        532,\n        272,\n        806,\n        379,\n        977,\n        974,\n        223,\n        266,\n        385,\n        127,\n        971,\n        31,\n        562,\n        1112,\n        1058,\n        624,\n        753,\n        798,\n        1161,\n        300,\n        281,\n        980,\n        552,\n        253,\n        42,\n        142,\n        228,\n        486,\n        573,\n        436,\n        1103,\n        407,\n        459,\n        781,\n        1023,\n        1142,\n        780,\n        1077,\n        225,\n        175,\n        517,\n        714,\n        163,\n        29,\n        1157,\n        622,\n        14,\n        6,\n        961,\n        473,\n        274,\n        660,\n        157,\n        329,\n        295,\n        1017,\n        241,\n        617,\n        1065,\n        1021,\n        158,\n        731,\n        828,\n        811,\n        1064,\n        227,\n        378,\n        888,\n        565,\n        1063,\n        586,\n   
     1116,\n        411,\n        1098,\n        707,\n        232,\n        229,\n        87,\n        1100,\n        423,\n        833,\n        1041,\n        838,\n        1134,\n        612,\n        133,\n        896,\n        214,\n        109,\n        578,\n        946,\n        887,\n        645,\n        398,\n        1121,\n        539,\n        757,\n        670,\n        212,\n        211,\n        687,\n        572,\n        865,\n        1196,\n        700,\n        80,\n        584,\n        149,\n        588,\n        662,\n        351,\n        1136,\n        972,\n        542,\n        708,\n        325,\n        521,\n        455,\n        1111,\n        1096,\n        957,\n        432,\n        306,\n        786,\n        655,\n        1033,\n        458,\n        591,\n        850,\n        1193,\n        134,\n        380,\n        1174,\n        1148,\n        1144,\n        1040,\n        1010,\n        962,\n        200,\n        99,\n        683,\n        876,\n        384,\n        282,\n        308,\n        482,\n        1192,\n        1123,\n        857,\n        357,\n        894,\n        583,\n        669,\n        117,\n        615,\n        250,\n        785,\n        466,\n        564,\n        139,\n        575,\n        1090,\n        607,\n        50,\n        415,\n        1053,\n        1097,\n        868,\n        950,\n        66,\n        498,\n        1180,\n        1118,\n        174,\n        152,\n        953,\n        193,\n        967,\n        898,\n        1084,\n        77,\n        1019,\n        847,\n        53,\n        705,\n        776,\n        345,\n        649,\n        1056,\n        268,\n        581,\n        201,\n        490,\n        331,\n        318,\n        635,\n        619,\n        672,\n        519,\n        598,\n        27,\n        632,\n        108,\n        492,\n        534,\n        126,\n        59,\n        276,\n        812,\n        0,\n        251,\n        855,\n        
484,\n        122,\n        448,\n        270,\n        1124,\n        596,\n        267,\n        416,\n        941,\n        835,\n        949,\n        964,\n        749,\n        930,\n        800,\n        277,\n        864,\n        356,\n        70,\n        15,\n        286,\n        71,\n        1013,\n        33,\n        1102,\n        1127,\n        460,\n        439,\n        897,\n        613,\n        332,\n        419,\n        172,\n        784,\n        504,\n        703,\n        933,\n        12,\n        844,\n        112,\n        606,\n        772,\n        948,\n        129,\n        4,\n        593,\n        408,\n        234,\n        993,\n        666,\n        637,\n        179,\n        512,\n        965,\n        839,\n        966,\n        110,\n        676,\n        463,\n        751,\n        421,\n        815,\n        465,\n        410,\n        1078,\n        794,\n        92,\n        1189,\n        1015,\n        485,\n        1067,\n        955,\n        233,\n        597,\n        1195,\n        119,\n        1059,\n        576,\n        135,\n        278,\n        1133,\n        440,\n        1035,\n        759,\n        147,\n        1107,\n        929,\n        344,\n        47,\n        245,\n        252,\n        914,\n        608,\n        654,\n        256,\n        369,\n        430,\n        885,\n        257,\n        339,\n        769,\n        145,\n        97,\n        680,\n        206,\n        323,\n        807,\n        1022,\n        752,\n        1138,\n        554,\n        787,\n        242,\n        604,\n        555,\n        738,\n        446,\n        721,\n        809,\n        171,\n        359,\n        106,\n        711,\n        130,\n        642,\n        322,\n        190,\n        1014,\n        488,\n        802,\n        362,\n        673,\n        178,\n        677,\n        904,\n        653,\n        1083,\n        788,\n        1105,\n        630,\n        799,\n        328,\n        
775,\n        713,\n        1045,\n        350,\n        409,\n        387,\n        563,\n        1188,\n        651,\n        702,\n        915,\n        1122,\n        189,\n        919,\n        194,\n        480,\n        243,\n        101,\n        162,\n        334,\n        1106,\n        755,\n        905,\n        137,\n        813,\n        298,\n        452,\n        889,\n        170,\n        371,\n        767,\n        716,\n        79,\n        852,\n        326,\n        218,\n        373,\n        388,\n        644,\n        827,\n        62,\n        284,\n        535,\n        836,\n        574,\n        886,\n        238,\n        41,\n        681,\n        91,\n        643,\n        516,\n        153,\n        543,\n        1141,\n        901,\n        990,\n        648,\n        520,\n        95,\n        1049,\n        893,\n        658,\n        319,\n        661,\n        341,\n        311,\n        765,\n        264,\n        926,\n        309,\n        551,\n        881,\n        136,\n        1132,\n        1095,\n        958,\n        762,\n        453,\n        577,\n        1008,\n        550,\n        34,\n        963,\n        1,\n        729,\n        1024,\n        289,\n        10,\n        826,\n        301,\n        1074,\n        1087,\n        390,\n        1145,\n        166,\n        55,\n        1091,\n        640,\n        691,\n        1140,\n        790,\n        779,\n        1164,\n        934,\n        81,\n        28,\n        471,\n        514,\n        273,\n        470,\n        903,\n        1082,\n        496,\n        1129,\n        67,\n        777,\n        1002,\n        824,\n        939,\n        1081,\n        358,\n        1173,\n        209,\n        478,\n        177,\n        579,\n        1026,\n        1080,\n        493,\n        987,\n        694,\n        627,\n        1154,\n        141,\n        626,\n        1104,\n        922,\n        594,\n        600,\n        791,\n        204,\n        
1143,\n        873,\n        1011,\n        220,\n        395,\n        620,\n        582,\n        1187,\n        221,\n        770,\n        1115,\n        973,\n        181,\n        1165,\n        803,\n        728,\n        842,\n        1120,\n        320,\n        11,\n        650,\n        1184,\n        682,\n        412,\n        192,\n        636,\n        545,\n        230,\n        743,\n        1089,\n        956,\n        763,\n        730,\n        338,\n        560,\n        368,\n        1169,\n        841,\n        890,\n        960,\n        748,\n        727,\n        259,\n        414,\n        280,\n        494,\n        54,\n        522,\n        402,\n        184,\n        754,\n        392,\n        609,\n        30,\n        878,\n        451,\n        477,\n        148,\n        989,\n        587,\n        782,\n        801,\n        367,\n        413,\n        240,\n        720,\n        1168,\n        161,\n        1197,\n        263,\n        526,\n        732,\n        254,\n        528,\n        1172,\n        854,\n        433,\n        924,\n        553,\n        159,\n        114,\n        293,\n        1119,\n        176,\n        449,\n        337,\n        843,\n        686,\n        725,\n        869,\n        399,\n        871,\n        131,\n        38,\n        663,\n        829,\n        697,\n        1079,\n        884,\n        1186,\n        426,\n        988,\n        1031,\n        558,\n        180,\n        364,\n        72,\n        98,\n        589,\n        837,\n        877,\n        599,\n        86,\n        715,\n        695,\n        712,\n        437,\n        561,\n        265,\n        610,\n        500,\n        160,\n        1007,\n        1150,\n        303,\n        696,\n        856,\n        503,\n        429,\n        823,\n        1153,\n        1027,\n        489,\n        538,\n        495,\n        275,\n        383,\n        817,\n        144,\n        468,\n        366,\n        297,\n       
 509,\n        722,\n        927,\n        44,\n        795,\n        481,\n        674,\n        1128,\n        1004,\n        1177,\n        709,\n        995,\n        400,\n        656,\n        342,\n        870,\n        1068,\n        82,\n        125,\n        936,\n        260,\n        124,\n        908,\n        404,\n        353,\n        771,\n        143,\n        938,\n        733,\n        401,\n        382,\n        1072,\n        118,\n        1038,\n        502,\n        773,\n        224,\n        78,\n        740,\n        978,\n        959,\n        24,\n        291,\n        5,\n        75,\n        628,\n        602,\n        1126,\n        213,\n        1036,\n        497,\n        1176,\n        420,\n        61,\n        462,\n        831,\n        16,\n        845,\n        688,\n        793,\n        860,\n        203,\n        313,\n        1006,\n        103,\n        422,\n        994,\n        895,\n        1062,\n        315,\n        381\n    ]\n}"
  },
  {
    "path": "datainfo/single_pendulum/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        1194,\n        644,\n        1184,\n        23,\n        694,\n        620,\n        1067,\n        1157,\n        895,\n        1155,\n        486,\n        883,\n        555,\n        1153,\n        210,\n        1152,\n        1065,\n        942,\n        771,\n        1121,\n        283,\n        737,\n        642,\n        695,\n        86,\n        319,\n        539,\n        597,\n        835,\n        404,\n        64,\n        1096,\n        221,\n        157,\n        14,\n        634,\n        1161,\n        483,\n        1035,\n        571,\n        677,\n        773,\n        92,\n        90,\n        243,\n        850,\n        1167,\n        601,\n        41,\n        1181,\n        840,\n        136,\n        704,\n        181,\n        990,\n        987,\n        129,\n        254,\n        583,\n        546,\n        432,\n        213,\n        1109,\n        668,\n        334,\n        572,\n        58,\n        689,\n        475,\n        834,\n        1093,\n        718,\n        790,\n        1038,\n        862,\n        1172,\n        893,\n        528,\n        444,\n        1013,\n        278,\n        73,\n        199,\n        748,\n        274,\n        910,\n        808,\n        874,\n        793,\n        968,\n        551,\n        63,\n        616,\n        87,\n        326,\n        131,\n        31,\n        798,\n        1071,\n        310,\n        474,\n        308,\n        813,\n        975,\n        1125,\n        1107,\n        963,\n        392,\n        479,\n        1128,\n        531,\n        960,\n        26,\n        134,\n        1189,\n        970,\n        757,\n        267,\n        1114,\n        487\n    ],\n    \"val\": [\n        710,\n        822,\n        925,\n        35,\n        46,\n        250,\n        829,\n        122,\n        293,\n        602,\n        878,\n        1029,\n        882,\n        232,\n        911,\n        349,\n        556,\n        
295,\n        1190,\n        765,\n        678,\n        1079,\n        856,\n        467,\n        763,\n        33,\n        227,\n        734,\n        949,\n        972,\n        1145,\n        1158,\n        522,\n        637,\n        901,\n        852,\n        965,\n        1047,\n        4,\n        204,\n        159,\n        423,\n        1097,\n        36,\n        419,\n        581,\n        429,\n        297,\n        426,\n        649,\n        354,\n        1148,\n        277,\n        870,\n        812,\n        530,\n        298,\n        431,\n        132,\n        603,\n        353,\n        1051,\n        825,\n        175,\n        696,\n        570,\n        375,\n        645,\n        390,\n        69,\n        247,\n        460,\n        554,\n        923,\n        446,\n        1147,\n        163,\n        346,\n        897,\n        459,\n        683,\n        659,\n        208,\n        198,\n        891,\n        383,\n        671,\n        488,\n        1170,\n        455,\n        1024,\n        1115,\n        269,\n        55,\n        214,\n        772,\n        615,\n        540,\n        1086,\n        640,\n        379,\n        745,\n        363,\n        655,\n        611,\n        934,\n        514,\n        756,\n        43,\n        124,\n        45,\n        1002,\n        1119,\n        1074,\n        722,\n        954,\n        680,\n        123,\n        272,\n        1098\n    ],\n    \"train\": [\n        464,\n        1062,\n        890,\n        22,\n        660,\n        427,\n        290,\n        167,\n        469,\n        1159,\n        341,\n        712,\n        560,\n        457,\n        77,\n        952,\n        639,\n        462,\n        1185,\n        1195,\n        654,\n        135,\n        83,\n        797,\n        730,\n        372,\n        650,\n        218,\n        225,\n        787,\n        285,\n        265,\n        343,\n        279,\n        452,\n        1028,\n        397,\n        
951,\n        948,\n        421,\n        855,\n        625,\n        158,\n        846,\n        1004,\n        1009,\n        753,\n        72,\n        466,\n        559,\n        914,\n        80,\n        591,\n        738,\n        142,\n        1120,\n        847,\n        408,\n        708,\n        973,\n        552,\n        236,\n        364,\n        1186,\n        207,\n        606,\n        1103,\n        338,\n        741,\n        622,\n        282,\n        178,\n        991,\n        263,\n        1054,\n        1112,\n        612,\n        586,\n        1072,\n        1193,\n        371,\n        1129,\n        480,\n        66,\n        1176,\n        781,\n        691,\n        507,\n        735,\n        853,\n        91,\n        575,\n        386,\n        192,\n        928,\n        67,\n        335,\n        785,\n        656,\n        585,\n        220,\n        789,\n        374,\n        347,\n        706,\n        231,\n        196,\n        188,\n        752,\n        681,\n        348,\n        324,\n        1177,\n        309,\n        151,\n        727,\n        929,\n        525,\n        688,\n        100,\n        95,\n        1141,\n        1124,\n        352,\n        519,\n        1090,\n        624,\n        1075,\n        623,\n        266,\n        351,\n        219,\n        723,\n        1133,\n        470,\n        1085,\n        75,\n        38,\n        1150,\n        510,\n        1031,\n        998,\n        800,\n        8,\n        641,\n        676,\n        370,\n        104,\n        288,\n        766,\n        415,\n        286,\n        292,\n        521,\n        411,\n        42,\n        1182,\n        443,\n        1110,\n        1092,\n        811,\n        873,\n        1073,\n        731,\n        842,\n        977,\n        350,\n        561,\n        788,\n        424,\n        1179,\n        312,\n        876,\n        906,\n        141,\n        321,\n        228,\n        564,\n        294,\n      
  791,\n        414,\n        922,\n        76,\n        118,\n        328,\n        138,\n        1166,\n        1156,\n        356,\n        628,\n        714,\n        841,\n        777,\n        1142,\n        799,\n        803,\n        103,\n        875,\n        844,\n        638,\n        1063,\n        1055,\n        726,\n        953,\n        1113,\n        313,\n        518,\n        315,\n        449,\n        935,\n        1160,\n        65,\n        881,\n        482,\n        6,\n        633,\n        784,\n        296,\n        498,\n        10,\n        1081,\n        337,\n        331,\n        865,\n        867,\n        434,\n        1169,\n        380,\n        605,\n        394,\n        154,\n        889,\n        884,\n        393,\n        412,\n        143,\n        217,\n        500,\n        587,\n        783,\n        815,\n        1052,\n        945,\n        535,\n        262,\n        961,\n        1007,\n        635,\n        12,\n        284,\n        61,\n        1089,\n        767,\n        672,\n        1076,\n        1197,\n        607,\n        717,\n        1117,\n        327,\n        999,\n        866,\n        1149,\n        0,\n        399,\n        1010,\n        345,\n        230,\n        994,\n        667,\n        121,\n        1036,\n        49,\n        1144,\n        1199,\n        1101,\n        992,\n        489,\n        451,\n        249,\n        533,\n        979,\n        927,\n        137,\n        666,\n        595,\n        496,\n        770,\n        959,\n        212,\n        534,\n        1099,\n        978,\n        1100,\n        197,\n        89,\n        697,\n        686,\n        171,\n        913,\n        1017,\n        477,\n        1146,\n        211,\n        558,\n        827,\n        1108,\n        201,\n        156,\n        619,\n        1039,\n        179,\n        915,\n        1088,\n        679,\n        437,\n        169,\n        657,\n        96,\n        27,\n        246,\n   
     502,\n        627,\n        684,\n        473,\n        647,\n        1005,\n        53,\n        739,\n        1046,\n        303,\n        485,\n        851,\n        29,\n        955,\n        373,\n        576,\n        699,\n        172,\n        15,\n        1131,\n        287,\n        930,\n        725,\n        589,\n        238,\n        653,\n        148,\n        454,\n        304,\n        1171,\n        511,\n        59,\n        318,\n        170,\n        84,\n        97,\n        857,\n        1162,\n        48,\n        947,\n        112,\n        481,\n        252,\n        703,\n        941,\n        1026,\n        617,\n        981,\n        837,\n        1021,\n        16,\n        1042,\n        162,\n        664,\n        838,\n        760,\n        517,\n        1059,\n        251,\n        1033,\n        420,\n        1154,\n        886,\n        71,\n        1198,\n        325,\n        1104,\n        257,\n        1015,\n        299,\n        830,\n        1084,\n        614,\n        848,\n        506,\n        721,\n        966,\n        849,\n        715,\n        596,\n        1018,\n        107,\n        888,\n        759,\n        418,\n        305,\n        669,\n        504,\n        1014,\n        608,\n        1027,\n        536,\n        368,\n        19,\n        233,\n        1122,\n        224,\n        796,\n        1020,\n        447,\n        234,\n        698,\n        11,\n        545,\n        956,\n        750,\n        106,\n        216,\n        907,\n        34,\n        1135,\n        242,\n        899,\n        200,\n        82,\n        818,\n        450,\n        1140,\n        476,\n        209,\n        21,\n        402,\n        1053,\n        161,\n        248,\n        113,\n        387,\n        357,\n        819,\n        543,\n        833,\n        32,\n        1095,\n        1048,\n        916,\n        1087,\n        40,\n        366,\n        17,\n        396,\n        578,\n        1006,\n     
   439,\n        1080,\n        68,\n        983,\n        355,\n        682,\n        445,\n        168,\n        165,\n        648,\n        9,\n        898,\n        743,\n        39,\n        259,\n        119,\n        316,\n        1061,\n        950,\n        25,\n        472,\n        574,\n        782,\n        405,\n        513,\n        971,\n        1056,\n        456,\n        858,\n        674,\n        145,\n        289,\n        57,\n        986,\n        786,\n        1044,\n        1180,\n        709,\n        416,\n        821,\n        108,\n        1019,\n        940,\n        693,\n        828,\n        744,\n        646,\n        805,\n        177,\n        503,\n        239,\n        905,\n        2,\n        859,\n        526,\n        824,\n        764,\n        920,\n        317,\n        871,\n        562,\n        417,\n        1049,\n        631,\n        300,\n        378,\n        610,\n        206,\n        588,\n        174,\n        1043,\n        407,\n        149,\n        912,\n        621,\n        1008,\n        193,\n        189,\n        1139,\n        673,\n        565,\n        817,\n        155,\n        114,\n        816,\n        361,\n        806,\n        458,\n        413,\n        340,\n        362,\n        700,\n        253,\n        794,\n        516,\n        685,\n        845,\n        705,\n        832,\n        1187,\n        900,\n        1078,\n        754,\n        367,\n        401,\n        976,\n        755,\n        185,\n        969,\n        1041,\n        180,\n        1105,\n        768,\n        468,\n        652,\n        944,\n        854,\n        974,\n        629,\n        144,\n        573,\n        110,\n        1130,\n        742,\n        1060,\n        775,\n        153,\n        186,\n        724,\n        301,\n        60,\n        924,\n        520,\n        505,\n        936,\n        538,\n        939,\n        579,\n        1037,\n        1058,\n        860,\n        1050,\n       
 461,\n        377,\n        807,\n        1064,\n        99,\n        256,\n        1057,\n        448,\n        115,\n        442,\n        78,\n        281,\n        917,\n        885,\n        557,\n        932,\n        400,\n        872,\n        701,\n        215,\n        747,\n        497,\n        388,\n        1034,\n        339,\n        342,\n        30,\n        919,\n        344,\n        670,\n        463,\n        365,\n        1003,\n        1016,\n        270,\n        780,\n        120,\n        1173,\n        93,\n        894,\n        187,\n        320,\n        663,\n        173,\n        261,\n        762,\n        636,\n        191,\n        140,\n        264,\n        194,\n        746,\n        260,\n        769,\n        880,\n        728,\n        235,\n        478,\n        964,\n        903,\n        74,\n        988,\n        809,\n        1025,\n        1143,\n        594,\n        1188,\n        1094,\n        736,\n        491,\n        758,\n        147,\n        879,\n        329,\n        1196,\n        861,\n        1138,\n        176,\n        205,\n        391,\n        490,\n        1164,\n        1083,\n        314,\n        222,\n        1132,\n        692,\n        598,\n        609,\n        938,\n        687,\n        302,\n        550,\n        237,\n        826,\n        887,\n        599,\n        150,\n        957,\n        273,\n        139,\n        190,\n        863,\n        453,\n        720,\n        1151,\n        509,\n        98,\n        166,\n        152,\n        410,\n        658,\n        1000,\n        495,\n        702,\n        529,\n        690,\n        593,\n        582,\n        802,\n        425,\n        164,\n        422,\n        240,\n        512,\n        776,\n        85,\n        823,\n        569,\n        358,\n        406,\n        613,\n        376,\n        931,\n        403,\n        716,\n        630,\n        465,\n        255,\n        661,\n        1163,\n        1030,\n       
 79,\n        184,\n        761,\n        332,\n        563,\n        962,\n        566,\n        102,\n        567,\n        109,\n        943,\n        732,\n        541,\n        523,\n        779,\n        37,\n        1123,\n        306,\n        592,\n        291,\n        713,\n        493,\n        1118,\n        117,\n        1192,\n        804,\n        933,\n        24,\n        183,\n        384,\n        675,\n        101,\n        792,\n        94,\n        896,\n        1070,\n        864,\n        241,\n        146,\n        892,\n        5,\n        995,\n        1127,\n        105,\n        544,\n        577,\n        1174,\n        751,\n        111,\n        1040,\n        385,\n        542,\n        1178,\n        749,\n        982,\n        707,\n        1011,\n        1069,\n        268,\n        1134,\n        1045,\n        333,\n        28,\n        133,\n        436,\n        229,\n        435,\n        868,\n        604,\n        711,\n        276,\n        548,\n        223,\n        996,\n        18,\n        909,\n        428,\n        1191,\n        430,\n        869,\n        719,\n        778,\n        508,\n        1102,\n        381,\n        937,\n        441,\n        993,\n        44,\n        820,\n        651,\n        1082,\n        1032,\n        926,\n        665,\n        618,\n        471,\n        382,\n        1068,\n        182,\n        580,\n        81,\n        814,\n        389,\n        740,\n        733,\n        3,\n        568,\n        130,\n        126,\n        438,\n        336,\n        195,\n        271,\n        369,\n        311,\n        600,\n        398,\n        662,\n        395,\n        359,\n        1116,\n        322,\n        160,\n        323,\n        501,\n        1066,\n        527,\n        88,\n        729,\n        1126,\n        985,\n        494,\n        70,\n        902,\n        127,\n        47,\n        1136,\n        1168,\n        56,\n        877,\n        1,\n        839,\n 
       795,\n        330,\n        484,\n        52,\n        433,\n        532,\n        1106,\n        632,\n        275,\n        1137,\n        1012,\n        774,\n        51,\n        409,\n        1091,\n        584,\n        643,\n        499,\n        7,\n        843,\n        1175,\n        1022,\n        1165,\n        280,\n        810,\n        245,\n        958,\n        967,\n        836,\n        908,\n        547,\n        125,\n        202,\n        226,\n        360,\n        62,\n        801,\n        831,\n        997,\n        553,\n        989,\n        258,\n        128,\n        1183,\n        116,\n        626,\n        1111,\n        54,\n        1001,\n        549,\n        537,\n        20,\n        980,\n        244,\n        307,\n        515,\n        1023,\n        1077,\n        984,\n        492,\n        13,\n        50,\n        590,\n        440,\n        921,\n        904,\n        918,\n        203,\n        946,\n        524\n    ]\n}"
  },
  {
    "path": "datainfo/swingstick_magnetic/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        2,\n        18,\n        4\n    ],\n    \"val\": [\n        3,\n        8\n    ],\n    \"train\": [\n        19,\n        5,\n        13,\n        9,\n        20,\n        21,\n        11,\n        12,\n        0,\n        16,\n        1,\n        17,\n        6,\n        10,\n        7,\n        14,\n        15\n    ]\n}"
  },
  {
    "path": "datainfo/swingstick_magnetic/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        20,\n        2,\n        1\n    ],\n    \"val\": [\n        5,\n        11\n    ],\n    \"train\": [\n        18,\n        21,\n        13,\n        10,\n        4,\n        17,\n        7,\n        15,\n        6,\n        19,\n        12,\n        0,\n        14,\n        3,\n        16,\n        8,\n        9\n    ]\n}"
  },
  {
    "path": "datainfo/swingstick_magnetic/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        17,\n        18,\n        7\n    ],\n    \"val\": [\n        11,\n        4\n    ],\n    \"train\": [\n        1,\n        13,\n        16,\n        5,\n        10,\n        6,\n        19,\n        12,\n        14,\n        3,\n        8,\n        20,\n        21,\n        0,\n        9,\n        2,\n        15\n    ]\n}"
  },
  {
    "path": "datainfo/swingstick_non_magnetic/data_split_dict_1.json",
    "content": "{\n    \"test\": [\n        48,\n        60,\n        57,\n        63,\n        15,\n        32,\n        8,\n        72,\n        17\n    ],\n    \"val\": [\n        77,\n        0,\n        55,\n        49,\n        3,\n        62,\n        12,\n        26\n    ],\n    \"train\": [\n        69,\n        20,\n        80,\n        10,\n        4,\n        59,\n        67,\n        81,\n        5,\n        2,\n        37,\n        78,\n        83,\n        58,\n        71,\n        19,\n        38,\n        73,\n        25,\n        79,\n        70,\n        30,\n        9,\n        36,\n        51,\n        52,\n        53,\n        16,\n        54,\n        23,\n        47,\n        21,\n        7,\n        39,\n        11,\n        6,\n        45,\n        74,\n        50,\n        18,\n        65,\n        42,\n        44,\n        22,\n        75,\n        35,\n        31,\n        28,\n        14,\n        33,\n        76,\n        46,\n        27,\n        64,\n        43,\n        24,\n        56,\n        68,\n        66,\n        41,\n        61,\n        82,\n        1,\n        40,\n        13,\n        29,\n        34\n    ]\n}"
  },
  {
    "path": "datainfo/swingstick_non_magnetic/data_split_dict_2.json",
    "content": "{\n    \"test\": [\n        4,\n        27,\n        32,\n        39,\n        21,\n        46,\n        10,\n        11,\n        7\n    ],\n    \"val\": [\n        64,\n        56,\n        47,\n        65,\n        50,\n        55,\n        20,\n        74\n    ],\n    \"train\": [\n        0,\n        31,\n        70,\n        6,\n        49,\n        41,\n        2,\n        67,\n        17,\n        13,\n        62,\n        9,\n        5,\n        57,\n        19,\n        25,\n        78,\n        18,\n        37,\n        79,\n        45,\n        51,\n        83,\n        16,\n        69,\n        71,\n        36,\n        66,\n        12,\n        61,\n        30,\n        54,\n        38,\n        40,\n        22,\n        42,\n        72,\n        26,\n        28,\n        63,\n        53,\n        43,\n        23,\n        44,\n        77,\n        8,\n        48,\n        60,\n        52,\n        1,\n        14,\n        15,\n        82,\n        35,\n        81,\n        33,\n        68,\n        76,\n        24,\n        58,\n        73,\n        59,\n        29,\n        80,\n        3,\n        75,\n        34\n    ]\n}"
  },
  {
    "path": "datainfo/swingstick_non_magnetic/data_split_dict_3.json",
    "content": "{\n    \"test\": [\n        8,\n        74,\n        60,\n        77,\n        47,\n        16,\n        69,\n        75,\n        30\n    ],\n    \"val\": [\n        68,\n        73,\n        24,\n        29,\n        70,\n        33,\n        78,\n        1\n    ],\n    \"train\": [\n        3,\n        15,\n        14,\n        22,\n        21,\n        51,\n        46,\n        52,\n        35,\n        5,\n        41,\n        56,\n        40,\n        54,\n        62,\n        57,\n        53,\n        7,\n        64,\n        26,\n        45,\n        58,\n        11,\n        18,\n        12,\n        32,\n        67,\n        9,\n        34,\n        20,\n        44,\n        43,\n        80,\n        13,\n        31,\n        39,\n        66,\n        6,\n        23,\n        82,\n        28,\n        36,\n        25,\n        27,\n        61,\n        38,\n        83,\n        17,\n        76,\n        63,\n        2,\n        37,\n        48,\n        10,\n        4,\n        49,\n        42,\n        0,\n        79,\n        81,\n        72,\n        59,\n        55,\n        65,\n        71,\n        19,\n        50\n    ]\n}"
  },
  {
    "path": "dataset.py",
    "content": "\nimport os\nimport sys\nimport json\nimport glob\nimport torch\nimport itertools\nimport numpy as np\nfrom PIL import Image\nfrom scipy import misc\nfrom torch.utils.data import Dataset, DataLoader\nfrom torchvision import datasets, transforms\n\nclass NeuralPhysDataset(Dataset):\n    def __init__(self, data_filepath, flag, seed, object_name=\"double_pendulum\"):\n        self.seed = seed\n        self.flag = flag\n        self.object_name = object_name\n        self.data_filepath = data_filepath\n        self.all_filelist = self.get_all_filelist()\n\n    def get_all_filelist(self):\n        filelist = []\n        obj_filepath = os.path.join(self.data_filepath, self.object_name)\n        # get the video ids for the requested split (train, val or test)\n        with open(os.path.join('../datainfo', self.object_name, f'data_split_dict_{self.seed}.json'), 'r') as file:\n            seq_dict = json.load(file)\n        vid_list = seq_dict[self.flag]\n\n        # go through all the selected videos and collect windows of 4 consecutive frames: input (t, t+1), target (t+2, t+3)\n        for vid_idx in vid_list:\n            seq_filepath = os.path.join(obj_filepath, str(vid_idx))\n            num_frames = len(os.listdir(seq_filepath))\n            suf = os.listdir(seq_filepath)[0].split('.')[-1]\n            for p_frame in range(num_frames - 3):\n                par_list = []\n                for p in range(4):\n                    par_list.append(os.path.join(seq_filepath, str(p_frame + p) + '.' + suf))\n                filelist.append(par_list)\n        return filelist\n\n    def __len__(self):\n        return len(self.all_filelist)\n\n    # 0, 1 -> 2, 3\n    def __getitem__(self, idx):\n        par_list = self.all_filelist[idx]\n        data = []\n        for i in range(2):\n            data.append(self.get_data(par_list[i])) # 0, 1\n        data = torch.cat(data, 2) # concatenate frames along the width axis\n        target = []\n        target.append(self.get_data(par_list[-2])) # 2\n        target.append(self.get_data(par_list[-1])) # 3\n        target = torch.cat(target, 2)\n        filepath = '_'.join(par_list[0].split('/')[-2:])\n        return data, target, filepath\n\n    def get_data(self, filepath):\n        # load a frame, resize it to 128x128, scale pixels to [0, 1] and reorder HWC -> CHW\n        data = Image.open(filepath)\n        data = data.resize((128, 128))\n        data = np.array(data)\n        data = torch.tensor(data / 255.0)\n        data = data.permute(2, 0, 1).float()\n        return data\n\n\n\nclass NeuralPhysLatentDynamicsDataset(Dataset):\n    def __init__(self, data_filepath, flag, seed, object_name=\"double_pendulum\"):\n        self.seed = seed\n        self.flag = flag\n        self.object_name = object_name\n        self.data_filepath = data_filepath\n        self.all_filelist = self.get_all_filelist()\n\n    def get_all_filelist(self):\n        filelist = []\n        obj_filepath = os.path.join(self.data_filepath, self.object_name)\n        # get the video ids for the requested split (train, val or test)\n        with open(os.path.join('../datainfo', self.object_name, f'data_split_dict_{self.seed}.json'), 'r') as file:\n            seq_dict = json.load(file)\n        vid_list = seq_dict[self.flag]\n\n        # go through all the selected videos and collect windows of 6 consecutive frames: input (t, t+1), target (t+2, t+3), target_target (t+4, t+5)\n        for vid_idx in vid_list:\n            seq_filepath = os.path.join(obj_filepath, str(vid_idx))\n            num_frames = len(os.listdir(seq_filepath))\n            suf = os.listdir(seq_filepath)[0].split('.')[-1]\n            for p_frame in range(num_frames - 5):\n                par_list = []\n                for p in range(6):\n                    par_list.append(os.path.join(seq_filepath, str(p_frame + p) + '.' + suf))\n                filelist.append(par_list)\n        return filelist\n\n    def __len__(self):\n        return len(self.all_filelist)\n\n    # 0, 1 -> 2, 3 -> 4, 5\n    def __getitem__(self, idx):\n        par_list = self.all_filelist[idx]\n        data = []\n        for i in range(2):\n            data.append(self.get_data(par_list[i])) # 0, 1\n        data = torch.cat(data, 2) # concatenate frames along the width axis\n        target = []\n        target.append(self.get_data(par_list[2])) # 2\n        target.append(self.get_data(par_list[3])) # 3\n        target = torch.cat(target, 2)\n        target_target = []\n        target_target.append(self.get_data(par_list[-2])) # 4\n        target_target.append(self.get_data(par_list[-1])) # 5\n        target_target = torch.cat(target_target, 2)\n        filepath = '_'.join(par_list[0].split('/')[-2:])\n        return data, target, target_target, filepath\n\n    def get_data(self, filepath):\n        # load a frame, resize it to 128x128, scale pixels to [0, 1] and reorder HWC -> CHW\n        data = Image.open(filepath)\n        data = data.resize((128, 128))\n        data = np.array(data)\n        data = torch.tensor(data / 255.0)\n        data = data.permute(2, 0, 1).float()\n        return data"
  },
  {
    "path": "eval.py",
    "content": "\n\nimport os\nimport sys\nimport glob\nimport yaml\nimport torch\nimport pprint\nimport shutil\nimport numpy as np\nfrom tqdm import tqdm\nfrom munch import munchify\nfrom collections import OrderedDict\nfrom models import VisDynamicsModel\nfrom models_latentpred import VisLatentDynamicsModel\nfrom dataset import NeuralPhysDataset\nfrom pytorch_lightning import Trainer, seed_everything\nfrom pytorch_lightning.loggers import TensorBoardLogger\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\ndef load_config(filepath):\n    with open(filepath, 'r') as stream:\n        try:\n            trainer_params = yaml.safe_load(stream)\n            return trainer_params\n        except yaml.YAMLError as exc:\n            print(exc)\n\ndef seed(cfg):\n    torch.manual_seed(cfg.seed)\n    if cfg.if_cuda:\n        torch.cuda.manual_seed(cfg.seed)\n\n\ndef main():\n    config_filepath = str(sys.argv[1])\n    checkpoint_filepath = str(sys.argv[2])\n    checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisDynamicsModel(lr=cfg.lr,\n                             seed=cfg.seed,\n                             if_cuda=cfg.if_cuda,\n                             if_test=True,\n                             gamma=cfg.gamma,\n                             log_dir=log_dir,\n                             train_batch=cfg.train_batch,\n                             val_batch=cfg.val_batch,\n                             test_batch=cfg.test_batch,\n                             num_workers=cfg.num_workers,\n                             model_name=cfg.model_name,\n                    
         data_filepath=cfg.data_filepath,\n                             dataset=cfg.dataset,\n                             lr_schedule=cfg.lr_schedule)\n\n    ckpt = torch.load(checkpoint_filepath)\n    if 'refine' in cfg.model_name:\n        ckpt = rename_ckpt_for_multi_models(ckpt)\n        model.model.load_state_dict(ckpt)\n\n        high_dim_checkpoint_filepath = str(sys.argv[3])\n        high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(high_dim_checkpoint_filepath)\n        ckpt = rename_ckpt_for_multi_models(ckpt)\n        model.high_dim_model.load_state_dict(ckpt)\n\n    else:\n        model.load_state_dict(ckpt['state_dict'])\n\n    model.eval()\n    model.freeze()\n\n    trainer = Trainer(gpus=1,\n                      deterministic=True,\n                      amp_backend='native',\n                      default_root_dir=log_dir,\n                      val_check_interval=1.0)\n\n    trainer.test(model)\n    model.test_save()\n\n\ndef main_latentpred():\n    config_filepath = str(sys.argv[1])\n    checkpoint_filepath = str(sys.argv[2])\n    high_dim_checkpoint_filepath = str(sys.argv[3])\n    refine_checkpoint_filepath = str(sys.argv[4])\n    \n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisLatentDynamicsModel(lr=cfg.lr,\n                                   seed=cfg.seed,\n                                   if_cuda=cfg.if_cuda,\n                                   if_test=True,\n                                   gamma=cfg.gamma,\n                                   log_dir=log_dir,\n                                   train_batch=cfg.train_batch,\n                                   val_batch=cfg.val_batch,\n         
                          test_batch=cfg.test_batch,\n                                   num_workers=cfg.num_workers,\n                                   model_name=cfg.model_name,\n                                   data_filepath=cfg.data_filepath,\n                                   dataset=cfg.dataset,\n                                   lr_schedule=cfg.lr_schedule)\n\n    model.load_model(checkpoint_filepath)\n    model.load_high_dim_refine_model(high_dim_checkpoint_filepath, refine_checkpoint_filepath)\n    model.extract_decoder_from_refine_model()\n    model.eval()\n    model.freeze()\n\n    trainer = Trainer(gpus=1,\n                      deterministic=True,\n                      amp_backend='native',\n                      default_root_dir=log_dir,\n                      val_check_interval=1.0)\n\n    trainer.test(model)\n\n\ndef rename_ckpt_for_multi_models(ckpt):\n    renamed_state_dict = OrderedDict()\n    for k, v in ckpt['state_dict'].items():\n        if 'high_dim_model' in k:\n            name = k.replace('high_dim_model.', '')\n        else:\n            name = k.replace('model.', '')\n        renamed_state_dict[name] = v\n    return renamed_state_dict\n\n# gather latent variables by running training data on the trained high-dim models\ndef gather_latent_from_trained_high_dim_model():\n    config_filepath = str(sys.argv[1])\n    checkpoint_filepath = str(sys.argv[2])\n    checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisDynamicsModel(lr=cfg.lr,\n                             seed=cfg.seed,\n                             if_cuda=cfg.if_cuda,\n                             if_test=True,\n                    
         gamma=cfg.gamma,\n                             log_dir=log_dir,\n                             train_batch=cfg.train_batch,\n                             val_batch=cfg.val_batch,\n                             test_batch=cfg.test_batch,\n                             num_workers=cfg.num_workers,\n                             model_name=cfg.model_name,\n                             data_filepath=cfg.data_filepath,\n                             dataset=cfg.dataset,\n                             lr_schedule=cfg.lr_schedule)\n\n    ckpt = torch.load(checkpoint_filepath)\n    model.load_state_dict(ckpt['state_dict'])\n    model = model.to('cuda')\n    model.eval()\n    model.freeze()\n\n    # prepare train and val dataset\n    kwargs = {'num_workers': cfg.num_workers, 'pin_memory': True} if cfg.if_cuda else {}\n    train_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath,\n                                      flag='train',\n                                      seed=cfg.seed,\n                                      object_name=cfg.dataset)\n    val_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath,\n                                    flag='val',\n                                    seed=cfg.seed,\n                                    object_name=cfg.dataset)\n    # prepare train and val loader\n    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,\n                                               batch_size=cfg.train_batch,\n                                               shuffle=True,\n                                               **kwargs)\n    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,\n                                             batch_size=cfg.val_batch,\n                                             shuffle=False,\n                                             **kwargs)\n\n    # run train forward pass to save the latent vector for training the refine network later\n    all_filepaths = []\n    all_latents = 
[]\n    var_log_dir = os.path.join(log_dir, 'variables')\n    for batch_idx, (data, target, filepath) in enumerate(tqdm(train_loader)):\n        if cfg.model_name == 'encoder-decoder':\n            output, latent = model.model(data.cuda())\n        if cfg.model_name == 'encoder-decoder-64':\n            output, latent = model.model(data.cuda(), data.cuda(), False)\n        # save the latent vectors\n        all_filepaths.extend(filepath)\n        for idx in range(data.shape[0]):\n            latent_tmp = latent[idx].view(1, -1)[0]\n            latent_tmp = latent_tmp.cpu().detach().numpy()\n            all_latents.append(latent_tmp)\n\n    mkdir(var_log_dir+'_train')\n    np.save(os.path.join(var_log_dir+'_train', 'ids.npy'), all_filepaths)\n    np.save(os.path.join(var_log_dir+'_train', 'latent.npy'), all_latents)\n\n    # run val forward pass to save the latent vector for validating the refine network later\n    all_filepaths = []\n    all_latents = []\n    var_log_dir = os.path.join(log_dir, 'variables')\n    for batch_idx, (data, target, filepath) in enumerate(tqdm(val_loader)):\n        if cfg.model_name == 'encoder-decoder':\n            output, latent = model.model(data.cuda())\n        if cfg.model_name == 'encoder-decoder-64':\n            output, latent = model.model(data.cuda(), data.cuda(), False)\n        # save the latent vectors\n        all_filepaths.extend(filepath)\n        for idx in range(data.shape[0]):\n            latent_tmp = latent[idx].view(1, -1)[0]\n            latent_tmp = latent_tmp.cpu().detach().numpy()\n            all_latents.append(latent_tmp)\n\n    mkdir(var_log_dir+'_val')\n    np.save(os.path.join(var_log_dir+'_val', 'ids.npy'), all_filepaths)\n    np.save(os.path.join(var_log_dir+'_val', 'latent.npy'), all_latents)\n\n\n# gather latent variables by running training data on the trained refine models\ndef gather_latent_from_trained_refine_model():\n    config_filepath = str(sys.argv[1])\n    checkpoint_filepath = 
str(sys.argv[2])\n    checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisDynamicsModel(lr=cfg.lr,\n                             seed=cfg.seed,\n                             if_cuda=cfg.if_cuda,\n                             if_test=True,\n                             gamma=cfg.gamma,\n                             log_dir=log_dir,\n                             train_batch=cfg.train_batch,\n                             val_batch=cfg.val_batch,\n                             test_batch=cfg.test_batch,\n                             num_workers=cfg.num_workers,\n                             model_name=cfg.model_name,\n                             data_filepath=cfg.data_filepath,\n                             dataset=cfg.dataset,\n                             lr_schedule=cfg.lr_schedule)\n\n    ckpt = torch.load(checkpoint_filepath)\n    ckpt = rename_ckpt_for_multi_models(ckpt)\n    model.model.load_state_dict(ckpt)\n    high_dim_checkpoint_filepath = str(sys.argv[3])\n    high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0]\n    ckpt = torch.load(high_dim_checkpoint_filepath)\n    ckpt = rename_ckpt_for_multi_models(ckpt)\n    model.high_dim_model.load_state_dict(ckpt)\n    model = model.to('cuda')\n    model.eval()\n    model.freeze()\n\n    # prepare train and val dataset\n    kwargs = {'num_workers': cfg.num_workers, 'pin_memory': True} if cfg.if_cuda else {}\n    train_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath,\n                                      flag='train',\n                                      seed=cfg.seed,\n                                      
object_name=cfg.dataset)\n    val_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath,\n                                    flag='val',\n                                    seed=cfg.seed,\n                                    object_name=cfg.dataset)\n    # prepare train and val loader\n    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,\n                                               batch_size=cfg.train_batch,\n                                               shuffle=True,\n                                               **kwargs)\n    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,\n                                             batch_size=cfg.val_batch,\n                                             shuffle=False,\n                                             **kwargs)\n\n    # run train forward pass to save the latent vector for training data\n    all_filepaths = []\n    all_latents = []\n    all_refine_latents = []\n    all_reconstructed_latents = []\n    var_log_dir = os.path.join(log_dir, 'variables')\n    for batch_idx, (data, target, filepath) in enumerate(tqdm(train_loader)):\n        _, latent = model.high_dim_model(data.cuda(), data.cuda(), False)\n        latent = latent.squeeze(-1).squeeze(-1)\n        latent_reconstructed, latent_latent = model.model(latent)\n        # save the latent vectors\n        all_filepaths.extend(filepath)\n        for idx in range(data.shape[0]):\n            latent_tmp = latent[idx].view(1, -1)[0]\n            latent_tmp = latent_tmp.cpu().detach().numpy()\n            all_latents.append(latent_tmp)\n            # save latent_latent: the latent vector in the refine network\n            latent_latent_tmp = latent_latent[idx].view(1, -1)[0]\n            latent_latent_tmp = latent_latent_tmp.cpu().detach().numpy()\n            all_refine_latents.append(latent_latent_tmp)\n            # save latent_reconstructed: the latent vector reconstructed by the entire refine network\n            
latent_reconstructed_tmp = latent_reconstructed[idx].view(1, -1)[0]\n            latent_reconstructed_tmp = latent_reconstructed_tmp.cpu().detach().numpy()\n            all_reconstructed_latents.append(latent_reconstructed_tmp)\n\n    mkdir(var_log_dir+'_train')\n    np.save(os.path.join(var_log_dir+'_train', 'ids.npy'), all_filepaths)\n    np.save(os.path.join(var_log_dir+'_train', 'latent.npy'), all_latents)\n    np.save(os.path.join(var_log_dir+'_train', 'refine_latent.npy'), all_refine_latents)\n    np.save(os.path.join(var_log_dir+'_train', 'reconstructed_latent.npy'), all_reconstructed_latents)\n\n    # run val forward pass to save the latent vector for validation data\n    all_filepaths = []\n    all_latents = []\n    all_refine_latents = []\n    all_reconstructed_latents = []\n    var_log_dir = os.path.join(log_dir, 'variables')\n    for batch_idx, (data, target, filepath) in enumerate(tqdm(val_loader)):\n        _, latent = model.high_dim_model(data.cuda(), data.cuda(), False)\n        latent = latent.squeeze(-1).squeeze(-1)\n        latent_reconstructed, latent_latent = model.model(latent)\n        # save the latent vectors\n        all_filepaths.extend(filepath)\n        for idx in range(data.shape[0]):\n            latent_tmp = latent[idx].view(1, -1)[0]\n            latent_tmp = latent_tmp.cpu().detach().numpy()\n            all_latents.append(latent_tmp)\n            # save latent_latent: the latent vector in the refine network\n            latent_latent_tmp = latent_latent[idx].view(1, -1)[0]\n            latent_latent_tmp = latent_latent_tmp.cpu().detach().numpy()\n            all_refine_latents.append(latent_latent_tmp)\n            # save latent_reconstructed: the latent vector reconstructed by the entire refine network\n            latent_reconstructed_tmp = latent_reconstructed[idx].view(1, -1)[0]\n            latent_reconstructed_tmp = latent_reconstructed_tmp.cpu().detach().numpy()\n            
all_reconstructed_latents.append(latent_reconstructed_tmp)\n\n    mkdir(var_log_dir+'_val')\n    np.save(os.path.join(var_log_dir+'_val', 'ids.npy'), all_filepaths)\n    np.save(os.path.join(var_log_dir+'_val', 'latent.npy'), all_latents)\n    np.save(os.path.join(var_log_dir+'_val', 'refine_latent.npy'), all_refine_latents)\n    np.save(os.path.join(var_log_dir+'_val', 'reconstructed_latent.npy'), all_reconstructed_latents)\n\n\nif __name__ == '__main__':\n    # dispatch on the optional mode argument; guard the indices so that\n    # invocations without a mode argument fall through to main()\n    # instead of raising an IndexError\n    mode = str(sys.argv[4]) if len(sys.argv) > 4 else ''\n    if mode == 'eval-train':\n        gather_latent_from_trained_high_dim_model()\n    elif mode == 'eval-refine-train':\n        gather_latent_from_trained_refine_model()\n    elif len(sys.argv) > 5 and str(sys.argv[5]) == 'eval-latentpred':\n        main_latentpred()\n    else:\n        main()"
  },
  {
    "path": "main.py",
    "content": "\nimport os\nimport sys\nimport yaml\nimport torch\nimport pprint\nfrom munch import munchify\nfrom models import VisDynamicsModel\nfrom models_latentpred import VisLatentDynamicsModel\nfrom pytorch_lightning import Trainer, seed_everything\nfrom pytorch_lightning.loggers import TensorBoardLogger\nfrom pytorch_lightning.callbacks import ModelCheckpoint\n\n\ndef load_config(filepath):\n    with open(filepath, 'r') as stream:\n        try:\n            trainer_params = yaml.safe_load(stream)\n            return trainer_params\n        except yaml.YAMLError as exc:\n            # report the parse error, then re-raise so the run fails here\n            # instead of later on a None config\n            print(exc)\n            raise\n\ndef seed(cfg):\n    torch.manual_seed(cfg.seed)\n    if cfg.if_cuda:\n        torch.cuda.manual_seed(cfg.seed)\n\n\ndef main():\n    config_filepath = str(sys.argv[1])\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisDynamicsModel(lr=cfg.lr,\n                             seed=cfg.seed,\n                             if_cuda=cfg.if_cuda,\n                             if_test=False,\n                             gamma=cfg.gamma,\n                             log_dir=log_dir,\n                             train_batch=cfg.train_batch,\n                             val_batch=cfg.val_batch,\n                             test_batch=cfg.test_batch,\n                             num_workers=cfg.num_workers,\n                             model_name=cfg.model_name,\n                             data_filepath=cfg.data_filepath,\n                             dataset=cfg.dataset,\n                             lr_schedule=cfg.lr_schedule)\n\n    # define callback for selecting checkpoints during training\n    checkpoint_callback = ModelCheckpoint(\n        filepath=log_dir + 
\"/lightning_logs/checkpoints/{epoch}_{val_loss}\",\n        verbose=True,\n        monitor='val_loss',\n        mode='min',\n        prefix='')\n\n    # define trainer\n    trainer = Trainer(gpus=cfg.num_gpus,\n                      max_epochs=cfg.epochs,\n                      deterministic=True,\n                      accelerator='ddp',\n                      amp_backend='native',\n                      default_root_dir=log_dir,\n                      val_check_interval=1.0,\n                      checkpoint_callback=checkpoint_callback)\n\n    trainer.fit(model)\n\ndef main_latentpred():\n    config_filepath = str(sys.argv[2])\n    high_dim_checkpoint_filepath = str(sys.argv[3])\n    refine_checkpoint_filepath = str(sys.argv[4])\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisLatentDynamicsModel(lr=cfg.lr,\n                                   seed=cfg.seed,\n                                   if_cuda=cfg.if_cuda,\n                                   if_test=False,\n                                   gamma=cfg.gamma,\n                                   log_dir=log_dir,\n                                   train_batch=cfg.train_batch,\n                                   val_batch=cfg.val_batch,\n                                   test_batch=cfg.test_batch,\n                                   num_workers=cfg.num_workers,\n                                   model_name=cfg.model_name,\n                                   data_filepath=cfg.data_filepath,\n                                   dataset=cfg.dataset,\n                                   lr_schedule=cfg.lr_schedule)\n\n    model.load_high_dim_refine_model(high_dim_checkpoint_filepath, refine_checkpoint_filepath)\n\n    # define callback for 
selecting checkpoints during training\n    checkpoint_callback = ModelCheckpoint(\n        filepath=log_dir + \"/lightning_logs/checkpoints/{epoch}_{val_loss}\",\n        verbose=True,\n        monitor='val_loss',\n        mode='min',\n        prefix='')\n\n    # define trainer\n    trainer = Trainer(gpus=cfg.num_gpus,\n                      max_epochs=cfg.epochs,\n                      deterministic=True,\n                      accelerator='ddp',\n                      amp_backend='native',\n                      default_root_dir=log_dir,\n                      val_check_interval=1.0,\n                      checkpoint_callback=checkpoint_callback)\n\n    trainer.fit(model)\n\nif __name__ == '__main__':\n    if sys.argv[1] == 'latentpred':\n        main_latentpred()\n    else:\n        main()\n"
  },
  {
    "path": "model_utils.py",
    "content": "\nimport torch\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef conv2d_bn_relu(inch,outch,kernel_size,stride=1,padding=1):\n    convlayer = torch.nn.Sequential(\n        torch.nn.Conv2d(inch,outch,kernel_size=kernel_size,stride=stride,padding=padding),\n        torch.nn.BatchNorm2d(outch),\n        torch.nn.ReLU()\n    )\n    return convlayer\n\ndef conv2d_bn_sigmoid(inch,outch,kernel_size,stride=1,padding=1):\n    convlayer = torch.nn.Sequential(\n        torch.nn.Conv2d(inch,outch,kernel_size=kernel_size,stride=stride,padding=padding),\n        torch.nn.BatchNorm2d(outch),\n        torch.nn.Sigmoid()\n    )\n    return convlayer\n\ndef deconv_sigmoid(inch,outch,kernel_size,stride=1,padding=1):\n    convlayer = torch.nn.Sequential(\n        torch.nn.ConvTranspose2d(inch,outch,kernel_size=kernel_size,stride=stride,padding=padding),\n        torch.nn.Sigmoid()\n    )\n    return convlayer\n\ndef deconv_relu(inch,outch,kernel_size,stride=1,padding=1):\n    convlayer = torch.nn.Sequential(\n        torch.nn.ConvTranspose2d(inch,outch,kernel_size=kernel_size,stride=stride,padding=padding),\n        torch.nn.BatchNorm2d(outch),\n        torch.nn.ReLU()\n    )\n    return convlayer\n\n\n\nclass EncoderDecoder(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(EncoderDecoder,self).__init__()\n\n        self.conv_stack1 = torch.nn.Sequential(\n            conv2d_bn_relu(in_channels,32,4,stride=2),\n            conv2d_bn_relu(32,32,3)\n        )\n        self.conv_stack2 = torch.nn.Sequential(\n            conv2d_bn_relu(32,32,4,stride=2),\n            conv2d_bn_relu(32,32,3)\n        )\n        self.conv_stack3 = torch.nn.Sequential(\n            conv2d_bn_relu(32,64,4,stride=2),\n            conv2d_bn_relu(64,64,3)\n        )\n        self.conv_stack4 = torch.nn.Sequential(\n            conv2d_bn_relu(64,128,4,stride=2),\n            conv2d_bn_relu(128,128,3),\n        )\n        self.conv_stack5 = 
torch.nn.Sequential(\n            conv2d_bn_relu(128,128,(3,4),stride=(1,2)),\n            conv2d_bn_relu(128,128,3),\n        )\n        \n        self.deconv_5 = deconv_relu(128,64,(3,4),stride=(1,2))\n        self.deconv_4 = deconv_relu(67,64,4,stride=2)\n        self.deconv_3 = deconv_relu(67,32,4,stride=2)\n        self.deconv_2 = deconv_relu(35,16,4,stride=2)\n        self.deconv_1 = deconv_sigmoid(19,3,4,stride=2)\n\n        self.predict_5 = torch.nn.Conv2d(128,3,3,stride=1,padding=1)\n        self.predict_4 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)\n        self.predict_3 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)\n        self.predict_2 = torch.nn.Conv2d(35,3,3,stride=1,padding=1)\n\n        self.up_sample_5 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,(3,4),stride=(1,2),padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n        self.up_sample_4 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n        self.up_sample_3 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n        self.up_sample_2 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n\n\n    def encoder(self, x):\n        conv1_out = self.conv_stack1(x)\n        conv2_out = self.conv_stack2(conv1_out)\n        conv3_out = self.conv_stack3(conv2_out)\n        conv4_out = self.conv_stack4(conv3_out)\n        conv5_out = self.conv_stack5(conv4_out)\n        return conv5_out\n\n    def decoder(self, x):\n        deconv5_out = self.deconv_5(x)\n        predict_5_out = self.up_sample_5(self.predict_5(x))\n\n        concat_5 = torch.cat([deconv5_out, predict_5_out],dim=1)\n        deconv4_out = self.deconv_4(concat_5)\n        predict_4_out = self.up_sample_4(self.predict_4(concat_5))\n\n    
    concat_4 = torch.cat([deconv4_out,predict_4_out],dim=1)\n        deconv3_out = self.deconv_3(concat_4)\n        predict_3_out = self.up_sample_3(self.predict_3(concat_4))\n\n        concat2 = torch.cat([deconv3_out,predict_3_out],dim=1)\n        deconv2_out = self.deconv_2(concat2)\n        predict_2_out = self.up_sample_2(self.predict_2(concat2))\n\n        concat1 = torch.cat([deconv2_out,predict_2_out],dim=1)\n        predict_out = self.deconv_1(concat1)\n        return predict_out\n        \n\n    def forward(self,x):\n        latent = self.encoder(x)\n        out = self.decoder(latent)\n        return out, latent\n\nclass EncoderDecoder64x1x1(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(EncoderDecoder64x1x1,self).__init__()\n\n        self.conv_stack1 = torch.nn.Sequential(\n            conv2d_bn_relu(in_channels,32,4,stride=2),\n            conv2d_bn_relu(32,32,3)\n        )\n        self.conv_stack2 = torch.nn.Sequential(\n            conv2d_bn_relu(32,32,4,stride=2),\n            conv2d_bn_relu(32,32,3)\n        )\n        self.conv_stack3 = torch.nn.Sequential(\n            conv2d_bn_relu(32,64,4,stride=2),\n            conv2d_bn_relu(64,64,3)\n        )\n        self.conv_stack4 = torch.nn.Sequential(\n            conv2d_bn_relu(64,64,4,stride=2),\n            conv2d_bn_relu(64,64,3),\n        )\n\n        self.conv_stack5 = torch.nn.Sequential(\n            conv2d_bn_relu(64,64,4,stride=2),\n            conv2d_bn_relu(64,64,3),\n        )\n\n        self.conv_stack6 = torch.nn.Sequential(\n            conv2d_bn_relu(64,64,4,stride=2),\n            conv2d_bn_relu(64,64,3),\n        )\n\n        self.conv_stack7 = torch.nn.Sequential(\n            conv2d_bn_relu(64,64,4,stride=2),\n            conv2d_bn_relu(64,64,3),\n        )\n\n        self.conv_stack8 = torch.nn.Sequential(\n            conv2d_bn_relu(64, 64, (3,4), stride=(1,2)),\n            conv2d_bn_relu(64, 64, 3),\n        )\n        \n        self.deconv_8 = 
deconv_relu(64,64,(3,4),stride=(1,2))\n        self.deconv_7 = deconv_relu(67,64,4,stride=2)\n        self.deconv_6 = deconv_relu(67,64,4,stride=2)\n        self.deconv_5 = deconv_relu(67,64,4,stride=2)\n        self.deconv_4 = deconv_relu(67,64,4,stride=2)\n        self.deconv_3 = deconv_relu(67,32,4,stride=2)\n        self.deconv_2 = deconv_relu(35,16,4,stride=2)\n        self.deconv_1 = deconv_sigmoid(19,3,4,stride=2)\n\n        self.predict_8 = torch.nn.Conv2d(64,3,3,stride=1,padding=1)\n        self.predict_7 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)\n        self.predict_6 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)\n        self.predict_5 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)\n        self.predict_4 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)\n        self.predict_3 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)\n        self.predict_2 = torch.nn.Conv2d(35,3,3,stride=1,padding=1)\n\n        self.up_sample_8 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,(3,4),stride=(1,2),padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n\n        self.up_sample_7 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n\n        self.up_sample_6 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n\n        self.up_sample_5 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n\n        self.up_sample_4 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n        self.up_sample_3 = torch.nn.Sequential(\n            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n        self.up_sample_2 = torch.nn.Sequential(\n            
torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),\n            torch.nn.Sigmoid()\n        )\n\n\n    def encoder(self, x):\n        conv1_out = self.conv_stack1(x)\n        conv2_out = self.conv_stack2(conv1_out)\n        conv3_out = self.conv_stack3(conv2_out)\n        conv4_out = self.conv_stack4(conv3_out)\n        conv5_out = self.conv_stack5(conv4_out)\n        conv6_out = self.conv_stack6(conv5_out)\n        conv7_out = self.conv_stack7(conv6_out)\n        conv8_out = self.conv_stack8(conv7_out)\n        return conv8_out\n\n    def decoder(self, x):\n\n        deconv8_out = self.deconv_8(x)\n        predict_8_out = self.up_sample_8(self.predict_8(x))\n\n        concat_7 = torch.cat([deconv8_out, predict_8_out], dim=1)\n        deconv7_out = self.deconv_7(concat_7)\n        predict_7_out = self.up_sample_7(self.predict_7(concat_7))\n\n        concat_6 = torch.cat([deconv7_out,predict_7_out],dim=1)\n        deconv6_out = self.deconv_6(concat_6)\n        predict_6_out = self.up_sample_6(self.predict_6(concat_6))\n\n        concat_5 = torch.cat([deconv6_out,predict_6_out],dim=1)\n        deconv5_out = self.deconv_5(concat_5)\n        predict_5_out = self.up_sample_5(self.predict_5(concat_5))\n\n        concat_4 = torch.cat([deconv5_out,predict_5_out],dim=1)\n        deconv4_out = self.deconv_4(concat_4)\n        predict_4_out = self.up_sample_4(self.predict_4(concat_4))\n\n        concat_3 = torch.cat([deconv4_out,predict_4_out],dim=1)\n        deconv3_out = self.deconv_3(concat_3)\n        predict_3_out = self.up_sample_3(self.predict_3(concat_3))\n\n        concat2 = torch.cat([deconv3_out,predict_3_out],dim=1)\n        deconv2_out = self.deconv_2(concat2)\n        predict_2_out = self.up_sample_2(self.predict_2(concat2))\n\n        concat1 = torch.cat([deconv2_out,predict_2_out],dim=1)\n        predict_out = self.deconv_1(concat1)\n        return predict_out\n        \n\n    def forward(self,x, reconstructed_latent, refine_latent):\n        if 
not refine_latent:\n            # standard autoencoding: decode the encoder's own latent vector\n            latent = self.encoder(x)\n            out = self.decoder(latent)\n            return out, latent\n        else:\n            # decode the latent vector reconstructed by the refine network instead\n            latent = self.encoder(x)\n            out = self.decoder(reconstructed_latent)\n            return out, latent\n\nclass SirenLayer(nn.Module):\n    def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False):\n        super().__init__()\n        self.in_f = in_f\n        self.w0 = w0\n        self.linear = nn.Linear(in_f, out_f)\n        self.is_first = is_first\n        self.is_last = is_last\n        self.init_weights()\n    \n    def init_weights(self):\n        # SIREN initialization: U(-1/in_f, 1/in_f) for the first layer,\n        # U(-sqrt(6/in_f)/w0, sqrt(6/in_f)/w0) for all later layers\n        b = 1 / self.in_f if self.is_first else np.sqrt(6 / self.in_f) / self.w0\n        with torch.no_grad():\n            self.linear.weight.uniform_(-b, b)\n\n    def forward(self, x):\n        x = self.linear(x)\n        return x if self.is_last else torch.sin(self.w0 * x)\n\nclass RefineDoublePendulumModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineDoublePendulumModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 4)\n        self.layer5 = SirenLayer(4, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\nclass RefineElasticPendulumModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineElasticPendulumModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 
64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 6)\n        self.layer5 = SirenLayer(6, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\nclass RefineReactionDiffusionModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineReactionDiffusionModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 2)\n        self.layer5 = SirenLayer(2, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\nclass RefineSwingStickNonMagneticModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineSwingStickNonMagneticModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 4)\n        self.layer5 = SirenLayer(4, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = 
self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\nclass RefineSinglePendulumModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineSinglePendulumModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 2)\n        self.layer5 = SirenLayer(2, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\n\nclass RefineCircularMotionModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineCircularMotionModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 2)\n        self.layer5 = SirenLayer(2, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\nclass RefineAirDancerModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        
super(RefineAirDancerModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 8)\n        self.layer5 = SirenLayer(8, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\nclass RefineLavaLampModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineLavaLampModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 8)\n        self.layer5 = SirenLayer(8, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\nclass RefineFireModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineFireModel, self).__init__()\n\n        self.layer1 = SirenLayer(in_channels, 128, is_first=True)\n        self.layer2 = SirenLayer(128, 64)\n        self.layer3 = SirenLayer(64, 32)\n        self.layer4 = SirenLayer(32, 24)\n        self.layer5 = SirenLayer(24, 32)\n        self.layer6 = SirenLayer(32, 64)\n        self.layer7 = SirenLayer(64, 128)\n        
self.layer8 = SirenLayer(128, in_channels, is_last=True)\n\n    def forward(self, x):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        latent = self.layer4(x)\n        x = self.layer5(latent)\n        x = self.layer6(x)\n        x = self.layer7(x)\n        x = self.layer8(x)\n        return x, latent\n\n\nclass RefineModelReLU(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(RefineModelReLU, self).__init__()\n\n        self.layer1 = nn.Linear(in_channels, 128)\n        self.relu1 = nn.ReLU()\n        self.layer2 = nn.Linear(128, 64)\n        self.relu2 = nn.ReLU()\n        self.layer3 = nn.Linear(64, 4)\n        self.relu3 = nn.ReLU()\n        self.layer4 = nn.Linear(4, 64)\n        self.relu4 = nn.ReLU()\n        self.layer5 = nn.Linear(64, 128)\n        self.relu5 = nn.ReLU()\n        self.layer6 = nn.Linear(128, in_channels)\n\n    def forward(self, x):\n        x = self.relu1(self.layer1(x))\n        x = self.relu2(self.layer2(x))\n        latent = self.relu3(self.layer3(x))\n        x = self.relu4(self.layer4(latent))\n        x = self.relu5(self.layer5(x))\n        x = self.layer6(x)\n        return x, latent\n\n\nclass LatentPredModel(torch.nn.Module):\n    def __init__(self, in_channels):\n        super(LatentPredModel, self).__init__()\n\n        self.layer1 = nn.Linear(in_channels, 32)\n        self.relu1 = nn.ReLU()\n        self.layer2 = nn.Linear(32, 64)\n        self.relu2 = nn.ReLU()\n        self.layer3 = nn.Linear(64, 64)\n        self.relu3 = nn.ReLU()\n        self.layer4 = nn.Linear(64, 64)\n        self.relu4 = nn.ReLU()\n        self.layer5 = nn.Linear(64, 32)\n        self.relu5 = nn.ReLU()\n        self.layer6 = nn.Linear(32, in_channels)\n\n    def forward(self, x):\n        x = self.relu1(self.layer1(x))\n        x = self.relu2(self.layer2(x))\n        x = self.relu3(self.layer3(x))\n        x = self.relu4(self.layer4(x))\n        x = self.relu5(self.layer5(x))\n        x = 
self.layer6(x)\n        return x\n"
  },
  {
    "path": "models.py",
    "content": "\nimport os\nimport torch\nimport shutil\nimport numpy as np\nfrom torch import nn\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom torchvision import transforms\nimport torchvision.models as models\nfrom collections import OrderedDict\nfrom torch.utils.data import DataLoader\nfrom torchvision.utils import save_image\nfrom dataset import NeuralPhysDataset\nfrom model_utils import (EncoderDecoder,\n                         EncoderDecoder64x1x1,\n                         RefineDoublePendulumModel,\n                         RefineSinglePendulumModel,\n                         RefineCircularMotionModel,\n                         RefineModelReLU,\n                         RefineSwingStickNonMagneticModel,\n                         RefineAirDancerModel,\n                         RefineLavaLampModel,\n                         RefineFireModel,\n                         RefineElasticPendulumModel,\n                         RefineReactionDiffusionModel)\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\nclass VisDynamicsModel(pl.LightningModule):\n\n    def __init__(self,\n                 lr: float=1e-4,\n                 seed: int=1,\n                 if_cuda: bool=True,\n                 if_test: bool=False,\n                 gamma: float=0.5,\n                 log_dir: str='logs',\n                 train_batch: int=512,\n                 val_batch: int=256,\n                 test_batch: int=256,\n                 num_workers: int=8,\n                 model_name: str='encoder-decoder-64',\n                 data_filepath: str='data',\n                 dataset: str='single_pendulum',\n                 lr_schedule: list=[20, 50, 100]) -> None:\n        super().__init__()\n        self.save_hyperparameters()\n        self.kwargs = {'num_workers': self.hparams.num_workers, 'pin_memory': True} if self.hparams.if_cuda else {}\n        # create visualization saving folder if 
not testing (note: mkdir() deletes any existing folder first)\n        self.pred_log_dir = os.path.join(self.hparams.log_dir, 'predictions')\n        self.var_log_dir = os.path.join(self.hparams.log_dir, 'variables')\n        if not self.hparams.if_test:\n            mkdir(self.pred_log_dir)\n            mkdir(self.var_log_dir)\n\n        self.__build_model()\n\n    def __build_model(self):\n        # model\n        if self.hparams.model_name == 'encoder-decoder':\n            self.model = EncoderDecoder(in_channels=3)\n        if self.hparams.model_name == 'encoder-decoder-64':\n            self.model = EncoderDecoder64x1x1(in_channels=3)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'single_pendulum':\n            self.model = RefineSinglePendulumModel(in_channels=64)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'double_pendulum':\n            self.model = RefineDoublePendulumModel(in_channels=64)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'circular_motion':\n            self.model = RefineCircularMotionModel(in_channels=64)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'swingstick_non_magnetic':\n            self.model = RefineSwingStickNonMagneticModel(in_channels=64)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'air_dancer':\n            self.model = RefineAirDancerModel(in_channels=64)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'lava_lamp':\n            self.model = RefineLavaLampModel(in_channels=64)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'fire':\n            self.model = RefineFireModel(in_channels=64)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'elastic_pendulum':\n            self.model = RefineElasticPendulumModel(in_channels=64)\n        if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'reaction_diffusion':\n  
          self.model = RefineReactionDiffusionModel(in_channels=64)\n        if 'refine' in self.hparams.model_name and self.hparams.if_test:\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n        # loss\n        self.loss_func = nn.MSELoss()\n\n    def train_forward(self, x):\n        if self.hparams.model_name == 'encoder-decoder' or 'refine' in self.hparams.model_name:\n            output, latent = self.model(x)\n        if self.hparams.model_name == 'encoder-decoder-64':\n            output, latent = self.model(x, x, False)\n        return output, latent\n\n    def training_step(self, batch, batch_idx):\n        data, target, filepath = batch\n        output, latent = self.train_forward(data)\n\n        train_loss = self.loss_func(output, target)\n        self.log('train_loss', train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n        return train_loss\n\n    def validation_step(self, batch, batch_idx):\n        data, target, filepath = batch\n        output, latent = self.train_forward(data)\n\n        val_loss = self.loss_func(output, target)\n        self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n        return val_loss\n\n    def test_step(self, batch, batch_idx):\n        \n        if self.hparams.model_name == 'encoder-decoder' or self.hparams.model_name == 'encoder-decoder-64':\n            data, target, filepath = batch\n            if self.hparams.model_name == 'encoder-decoder':\n                output, latent = self.model(data)\n            if self.hparams.model_name == 'encoder-decoder-64':\n                output, latent = self.model(data, data, False)\n            test_loss = self.loss_func(output, target)\n            self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n            # save the output images and latent vectors\n            self.all_filepaths.extend(filepath)\n            for idx in range(data.shape[0]):\n     
           comparison = torch.cat([data[idx,:, :, :128].unsqueeze(0),\n                                        data[idx,:, :, 128:].unsqueeze(0),\n                                        target[idx, :, :, :128].unsqueeze(0),\n                                        target[idx, :, :, 128:].unsqueeze(0),\n                                        output[idx, :, :, :128].unsqueeze(0),\n                                        output[idx, :, :, 128:].unsqueeze(0)])\n                save_image(comparison.cpu(), os.path.join(self.pred_log_dir, filepath[idx]), nrow=1)\n                latent_tmp = latent[idx].view(1, -1)[0]\n                latent_tmp = latent_tmp.cpu().detach().numpy()\n                self.all_latents.append(latent_tmp)\n\n        if 'refine' in self.hparams.model_name:\n            data, target, filepath = batch\n            _, latent = self.high_dim_model(data, data, False)\n            latent = latent.squeeze(-1).squeeze(-1)\n            latent_reconstructed, latent_latent = self.model(latent)\n            output, _ = self.high_dim_model(data, latent_reconstructed.unsqueeze(2).unsqueeze(3), True)\n            # calculate losses\n            pixel_reconstruction_loss = self.loss_func(output, target)\n            test_loss = self.loss_func(latent_reconstructed, latent)\n            self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n            self.log('pixel_reconstruction_loss', pixel_reconstruction_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n            # save the output images and latent vectors\n            self.all_filepaths.extend(filepath)\n            for idx in range(data.shape[0]):\n                comparison = torch.cat([data[idx, :, :, :128].unsqueeze(0),\n                                        data[idx, :, :, 128:].unsqueeze(0),\n                                        target[idx, :, :, :128].unsqueeze(0),\n                                        target[idx, :, :, 
128:].unsqueeze(0),\n                                        output[idx, :, :, :128].unsqueeze(0),\n                                        output[idx, :, :, 128:].unsqueeze(0)])\n                save_image(comparison.cpu(), os.path.join(self.pred_log_dir, filepath[idx]), nrow=1)\n                latent_tmp = latent[idx].view(1, -1)[0]\n                latent_tmp = latent_tmp.cpu().detach().numpy()\n                self.all_latents.append(latent_tmp)\n                # save latent_latent: the latent vector in the refine network\n                latent_latent_tmp = latent_latent[idx].view(1, -1)[0]\n                latent_latent_tmp = latent_latent_tmp.cpu().detach().numpy()\n                self.all_refine_latents.append(latent_latent_tmp)\n                # save latent_reconstructed: the latent vector reconstructed by the entire refine network\n                latent_reconstructed_tmp = latent_reconstructed[idx].view(1, -1)[0]\n                latent_reconstructed_tmp = latent_reconstructed_tmp.cpu().detach().numpy()\n                self.all_reconstructed_latents.append(latent_reconstructed_tmp)\n\n\n    def test_save(self):\n        if self.hparams.model_name == 'encoder-decoder' or self.hparams.model_name == 'encoder-decoder-64':\n            np.save(os.path.join(self.var_log_dir, 'ids.npy'), self.all_filepaths)\n            np.save(os.path.join(self.var_log_dir, 'latent.npy'), self.all_latents)\n        if 'refine' in self.hparams.model_name:\n            np.save(os.path.join(self.var_log_dir, 'ids.npy'), self.all_filepaths)\n            np.save(os.path.join(self.var_log_dir, 'latent.npy'), self.all_latents)\n            np.save(os.path.join(self.var_log_dir, 'refine_latent.npy'), self.all_refine_latents)\n            np.save(os.path.join(self.var_log_dir, 'reconstructed_latent.npy'), self.all_reconstructed_latents)\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)\n        scheduler = 
torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.hparams.lr_schedule, gamma=self.hparams.gamma)\n        return [optimizer], [scheduler]\n    \n    def paths_to_tuple(self, paths):\n        new_paths = []\n        for i in range(len(paths)):\n            tmp = paths[i].split('.')[0].split('_')\n            new_paths.append((int(tmp[0]), int(tmp[1])))\n        return new_paths\n\n    def setup(self, stage=None):\n\n        if stage == 'fit':\n            # for the training of the refine network, we need to have the latent data as the dataset\n            if 'refine' in self.hparams.model_name:\n                high_dim_var_log_dir = self.var_log_dir.replace('refine', 'encoder-decoder')\n                train_data = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_train', 'latent.npy')))\n                train_target = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_train', 'latent.npy')))\n                val_data = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_val', 'latent.npy')))\n                val_target = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_val', 'latent.npy')))\n                train_filepaths = list(np.load(os.path.join(high_dim_var_log_dir+'_train', 'ids.npy')))\n                val_filepaths = list(np.load(os.path.join(high_dim_var_log_dir+'_val', 'ids.npy')))\n                # convert the file strings into tuple so that we can use TensorDataset to load everything together\n                train_filepaths = torch.Tensor(self.paths_to_tuple(train_filepaths))\n                val_filepaths = torch.Tensor(self.paths_to_tuple(val_filepaths))\n                self.train_dataset = torch.utils.data.TensorDataset(train_data, train_target, train_filepaths)\n                self.val_dataset = torch.utils.data.TensorDataset(val_data, val_target, val_filepaths)\n            else:\n                self.train_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,\n       
                                                flag='train',\n                                                       seed=self.hparams.seed,\n                                                       object_name=self.hparams.dataset)\n                self.val_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,\n                                                     flag='val',\n                                                     seed=self.hparams.seed,\n                                                     object_name=self.hparams.dataset)\n\n        if stage == 'test':\n            self.test_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,\n                                                  flag='test',\n                                                  seed=self.hparams.seed,\n                                                  object_name=self.hparams.dataset)\n            \n            # initialize lists for saving variables and latents during testing\n            self.all_filepaths = []\n            self.all_latents = []\n            self.all_refine_latents = []\n            self.all_reconstructed_latents = []\n\n    def train_dataloader(self):\n        train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset,\n                                                   batch_size=self.hparams.train_batch,\n                                                   shuffle=True,\n                                                   **self.kwargs)\n        return train_loader\n\n    def val_dataloader(self):\n        val_loader = torch.utils.data.DataLoader(dataset=self.val_dataset,\n                                                 batch_size=self.hparams.val_batch,\n                                                 shuffle=False,\n                                                 **self.kwargs)\n        return val_loader\n\n    def test_dataloader(self):\n        test_loader = torch.utils.data.DataLoader(dataset=self.test_dataset,\n     
                                             batch_size=self.hparams.test_batch,\n                                                  shuffle=False,\n                                                  **self.kwargs)\n        return test_loader"
  },
  {
    "path": "models_latentpred.py",
    "content": "\nimport os\nimport glob\nimport torch\nimport shutil\nimport numpy as np\nfrom torch import nn\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom torchvision import transforms\nimport torchvision.models as models\nfrom collections import OrderedDict\nfrom torch.utils.data import DataLoader\nfrom torchvision.utils import save_image\nfrom dataset import NeuralPhysDataset, NeuralPhysLatentDynamicsDataset\nfrom model_utils import (EncoderDecoder,\n                         EncoderDecoder64x1x1,\n                         RefineDoublePendulumModel,\n                         RefineSinglePendulumModel,\n                         RefineCircularMotionModel,\n                         RefineModelReLU,\n                         RefineSwingStickNonMagneticModel,\n                         RefineAirDancerModel,\n                         RefineLavaLampModel,\n                         RefineFireModel,\n                         RefineElasticPendulumModel,\n                         RefineReactionDiffusionModel,\n                         LatentPredModel)\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\ndef rename_ckpt_for_multi_models(ckpt):\n    renamed_state_dict = OrderedDict()\n    for k, v in ckpt['state_dict'].items():\n        if k.split('.')[0] == 'model':\n            name = k.replace('model.', '')\n            renamed_state_dict[name] = v\n    return renamed_state_dict\n\nclass VisLatentDynamicsModel(pl.LightningModule):\n\n    def __init__(self,\n                 lr: float=1e-4,\n                 seed: int=1,\n                 if_cuda: bool=True,\n                 if_test: bool=False,\n                 gamma: float=0.5,\n                 log_dir: str='logs',\n                 train_batch: int=512,\n                 val_batch: int=256,\n                 test_batch: int=256,\n                 num_workers: int=8,\n                 model_name: str='encoder-decoder-64',\n          
       data_filepath: str='data',\n                 dataset: str='single_pendulum',\n                 lr_schedule: list=[20, 50, 100]) -> None:\n        super().__init__()\n        self.save_hyperparameters()\n        self.kwargs = {'num_workers': self.hparams.num_workers, 'pin_memory': True} if self.hparams.if_cuda else {}\n        # set up the prediction and variable saving folders (created and cleared only when not testing)\n        self.pred_log_dir = os.path.join(self.hparams.log_dir, 'predictions')\n        self.var_log_dir = os.path.join(self.hparams.log_dir, 'variables')\n        if not self.hparams.if_test:\n            mkdir(self.pred_log_dir)\n            mkdir(self.var_log_dir)\n\n        self.__build_model()\n\n    def __build_model(self):\n        # model\n        if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'single_pendulum':\n            self.model = LatentPredModel(in_channels=2)\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n            self.refine_model = RefineSinglePendulumModel(in_channels=64)\n        if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'double_pendulum':\n            self.model = LatentPredModel(in_channels=4)\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n            self.refine_model = RefineDoublePendulumModel(in_channels=64)\n        if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'elastic_pendulum':\n            self.model = LatentPredModel(in_channels=6)\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n            self.refine_model = RefineElasticPendulumModel(in_channels=64)\n        if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'swingstick_non_magnetic':\n            self.model = LatentPredModel(in_channels=4)\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n            self.refine_model = RefineSwingStickNonMagneticModel(in_channels=64)\n        if
self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'air_dancer':\n            self.model = LatentPredModel(in_channels=8)\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n            self.refine_model = RefineAirDancerModel(in_channels=64)\n        if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'lava_lamp':\n            self.model = LatentPredModel(in_channels=8)\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n            self.refine_model = RefineLavaLampModel(in_channels=64)\n        if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'fire':\n            self.model = LatentPredModel(in_channels=24)\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n            self.refine_model = RefineFireModel(in_channels=64)\n        if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'reaction_diffusion':\n            self.model = LatentPredModel(in_channels=2)\n            self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)\n            self.refine_model = RefineReactionDiffusionModel(in_channels=64)\n        # loss\n        self.loss_func = nn.MSELoss()\n    \n    def load_model(self, checkpoint_filepath):\n        # load model for test\n        checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(checkpoint_filepath)\n        ckpt = rename_ckpt_for_multi_models(ckpt)\n        self.model.load_state_dict(ckpt)\n\n        for p in self.model.parameters():\n            p.requires_grad = False\n        self.model.eval()\n    \n    def load_high_dim_refine_model(self, high_dim_checkpoint_filepath, refine_checkpoint_filepath):\n        # load high-dim and refine models\n        high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(high_dim_checkpoint_filepath)\n        ckpt = 
rename_ckpt_for_multi_models(ckpt)\n        self.high_dim_model.load_state_dict(ckpt)\n\n        refine_checkpoint_filepath = glob.glob(os.path.join(refine_checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(refine_checkpoint_filepath)\n        ckpt = rename_ckpt_for_multi_models(ckpt)\n        self.refine_model.load_state_dict(ckpt)\n\n        for p in self.high_dim_model.parameters():\n            p.requires_grad = False\n        self.high_dim_model.eval()\n\n        for p in self.refine_model.parameters():\n            p.requires_grad = False\n        self.refine_model.eval()\n    \n    def extract_decoder_from_refine_model(self):\n        _layers = list(self.refine_model.children())[4:]\n        self.refine_model_decoder = torch.nn.Sequential(*_layers)\n\n        for p in self.refine_model_decoder.parameters():\n            p.requires_grad = False\n        self.refine_model_decoder.eval()\n\n    def data_to_state(self, data):\n        # (B, 3, 128, 256) -> (B, 64, 1, 1) -> (B, 64) -> (B, ID)\n        _, latent = self.high_dim_model(data, data, False)\n        latent = latent.squeeze(-1).squeeze(-1)\n        _, state = self.refine_model(latent)\n        return state\n\n    def training_step(self, batch, batch_idx):\n        data, target, filepath = batch\n        data_state = self.data_to_state(data)\n        target_state = self.data_to_state(target)\n        output_state = self.model(data_state)\n\n        train_loss = self.loss_func(output_state, target_state)\n        self.log('train_loss', train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n        return train_loss\n\n    def validation_step(self, batch, batch_idx):\n        data, target, filepath = batch\n        data_state = self.data_to_state(data)\n        target_state = self.data_to_state(target)\n        output_state = self.model(data_state)\n\n        val_loss = self.loss_func(output_state, target_state)\n        self.log('val_loss', val_loss, on_step=True, on_epoch=True, 
prog_bar=True, logger=True)\n        return val_loss\n\n    def test_step(self, batch, batch_idx):\n        data, target, target_target, filepath = batch\n        data_state = self.data_to_state(data)\n        target_state = self.data_to_state(target)\n        output_state = self.model(data_state)\n        \n        latent_reconstructed = self.refine_model_decoder(output_state)\n        output, _ = self.high_dim_model(data, latent_reconstructed.unsqueeze(2).unsqueeze(3), True)\n\n        # calculate losses\n        test_loss = self.loss_func(output_state, target_state)\n        pixel_reconstruction_loss = self.loss_func(output, target_target)\n        self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n        self.log('pixel_reconstruction_loss', pixel_reconstruction_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n        \n        # save the output images\n        for idx in range(data.shape[0]):\n            comparison = torch.cat([data[idx, :, :, :128].unsqueeze(0),\n                                    data[idx, :, :, 128:].unsqueeze(0),\n                                    target[idx, :, :, :128].unsqueeze(0),\n                                    target[idx, :, :, 128:].unsqueeze(0),\n                                    target_target[idx, :, :, :128].unsqueeze(0),\n                                    target_target[idx, :, :, 128:].unsqueeze(0),\n                                    output[idx, :, :, :128].unsqueeze(0),\n                                    output[idx, :, :, 128:].unsqueeze(0)])\n            save_image(comparison.cpu(), os.path.join(self.pred_log_dir, filepath[idx]), nrow=1)\n        \n\n    def configure_optimizers(self):\n        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.hparams.lr)\n        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.hparams.lr_schedule, gamma=self.hparams.gamma)\n        return [optimizer], 
[scheduler]\n\n    def setup(self, stage=None):\n        if stage == 'fit':\n            self.train_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,\n                                                    flag='train',\n                                                    seed=self.hparams.seed,\n                                                    object_name=self.hparams.dataset)\n            self.val_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,\n                                                    flag='val',\n                                                    seed=self.hparams.seed,\n                                                    object_name=self.hparams.dataset)\n\n        if stage == 'test':\n            self.test_dataset = NeuralPhysLatentDynamicsDataset(data_filepath=self.hparams.data_filepath,\n                                                                 flag='test',\n                                                                 seed=self.hparams.seed,\n                                                                 object_name=self.hparams.dataset)\n\n    def train_dataloader(self):\n        train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset,\n                                                   batch_size=self.hparams.train_batch,\n                                                   shuffle=True,\n                                                   **self.kwargs)\n        return train_loader\n\n    def val_dataloader(self):\n        val_loader = torch.utils.data.DataLoader(dataset=self.val_dataset,\n                                                 batch_size=self.hparams.val_batch,\n                                                 shuffle=False,\n                                                 **self.kwargs)\n        return val_loader\n\n    def test_dataloader(self):\n        test_loader = torch.utils.data.DataLoader(dataset=self.test_dataset,\n                                         
         batch_size=self.hparams.test_batch,\n                                                  shuffle=False,\n                                                  **self.kwargs)\n        return test_loader"
  },
  {
    "path": "pred.py",
    "content": "\n\"\"\"\nA. Long-term future prediction (model rollout)\n\n1. encoder-decoder (0, 1 -> 8192-dim latent -> 2', 3'):\n    - feed (2', 3') images as input to predict (4', 5') images ...\n2. encoder-decoder-64 (0, 1 -> 64-dim latent -> 2', 3'):\n    - feed (2', 3') images as input to predict (4', 5') images ...\n3. encoder-decoder-64 & refine-64 (0, 1 -> id-dim latent -> 2', 3'):\n    - feed (2', 3') images as input to predict (4', 5') images ...\n4. encoder-decoder-64 & refine-64 hybrid:\n    - use refine-64 model at certain prediction steps\n\nB. Long-term future prediction with perturbation (model rollout)\n\"\"\"\n\n\nimport os\nimport sys\nimport glob\nimport yaml\nimport json\nimport torch\nimport pprint\nimport shutil\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom munch import munchify\nfrom torchvision import transforms\nfrom collections import OrderedDict\nfrom models import VisDynamicsModel\nfrom models_latentpred import VisLatentDynamicsModel\nfrom dataset import NeuralPhysDataset\nfrom pytorch_lightning import Trainer, seed_everything\nfrom pytorch_lightning.loggers import TensorBoardLogger\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\ndef load_config(filepath):\n    with open(filepath, 'r') as stream:\n        try:\n            trainer_params = yaml.safe_load(stream)\n            return trainer_params\n        except yaml.YAMLError as exc:\n            print(exc)\n\ndef seed(cfg):\n    torch.manual_seed(cfg.seed)\n    if cfg.if_cuda:\n        torch.cuda.manual_seed(cfg.seed)\n    # uncomment for strict reproducibility\n    # torch.use_deterministic_algorithms(True)\n\n\ndef model_rollout():\n    config_filepath = str(sys.argv[2])\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n           
             cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisDynamicsModel(lr=cfg.lr,\n                             seed=cfg.seed,\n                             if_cuda=cfg.if_cuda,\n                             if_test=True,\n                             gamma=cfg.gamma,\n                             log_dir=log_dir,\n                             train_batch=cfg.train_batch,\n                             val_batch=cfg.val_batch,\n                             test_batch=cfg.test_batch,\n                             num_workers=cfg.num_workers,\n                             model_name=cfg.model_name,\n                             data_filepath=cfg.data_filepath,\n                             dataset=cfg.dataset,\n                             lr_schedule=cfg.lr_schedule)\n    \n    # load model\n    if cfg.model_name == 'encoder-decoder' or cfg.model_name == 'encoder-decoder-64':\n        checkpoint_filepath = str(sys.argv[3])\n        checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(checkpoint_filepath)\n        model.load_state_dict(ckpt['state_dict'])\n    \n    if 'refine' in cfg.model_name:\n        checkpoint_filepath = str(sys.argv[4])\n        checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(checkpoint_filepath)\n        ckpt = rename_ckpt_for_multi_models(ckpt)\n        model.model.load_state_dict(ckpt)\n\n        high_dim_checkpoint_filepath = str(sys.argv[3])\n        high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(high_dim_checkpoint_filepath)\n        ckpt = rename_ckpt_for_multi_models(ckpt)\n        model.high_dim_model.load_state_dict(ckpt)\n\n    model = model.to('cuda')\n    model.eval()\n    model.freeze()\n\n    # get all the test video ids\n    data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset)\n    with 
open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:\n        seq_dict = json.load(file)\n    test_vid_ids = seq_dict['test']\n\n    pred_len = int(sys.argv[6])\n    long_term_folder = os.path.join(log_dir, 'prediction_long_term', 'model_rollout')\n    loss_dict = {}\n\n    if cfg.model_name == 'encoder-decoder' or cfg.model_name == 'encoder-decoder-64':\n        for p_vid_idx in tqdm(test_vid_ids):\n            vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))\n            total_num_frames = len(os.listdir(vid_filepath))\n            suf = os.listdir(vid_filepath)[0].split('.')[-1]\n            data = None\n            saved_folder = os.path.join(long_term_folder, str(p_vid_idx))\n            mkdir(saved_folder)\n            loss_lst = []\n            for start_frame_idx in range(total_num_frames - 3):\n                if start_frame_idx % 2 != 0:\n                    continue\n                # take the initial input from ground truth data\n                if start_frame_idx % pred_len == 0:\n                    data = [get_data(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}')),\n                            get_data(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'))]\n                    data = (torch.cat(data, 2)).unsqueeze(0)\n                # get the target\n                target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')),\n                          get_data(os.path.join(vid_filepath, f'{start_frame_idx+3}.{suf}'))]\n                target = (torch.cat(target, 2)).unsqueeze(0)\n                # feed into the model\n                if cfg.model_name == 'encoder-decoder':\n                    output, latent = model.model(data.cuda())\n                if cfg.model_name == 'encoder-decoder-64':\n                    output, latent = model.model(data.cuda(), data.cuda(), False)\n\n                # compute loss\n                
loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy()))\n                \n                # save (2', 3'), (4', 5'), ...\n                img = tensor_to_img(output[0, :, :, :128])\n                img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}'))\n                img = tensor_to_img(output[0, :, :, 128:])\n                img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}'))\n\n                # the output becomes the input data in the next iteration\n                data = torch.tensor(output.cpu().detach().numpy()).float()\n\n            loss_dict[p_vid_idx] = loss_lst\n\n        # save the test loss for all the testing videos\n        with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file:\n            json.dump(loss_dict, file, indent=4)\n    \n    if 'refine' in cfg.model_name:\n        for p_vid_idx in tqdm(test_vid_ids):\n            vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))\n            total_num_frames = len(os.listdir(vid_filepath))\n            suf = os.listdir(vid_filepath)[0].split('.')[-1]\n            data = None\n            saved_folder = os.path.join(long_term_folder, str(p_vid_idx))\n            mkdir(saved_folder)\n            loss_lst = []\n            for start_frame_idx in range(total_num_frames - 3):\n                if start_frame_idx % 2 != 0:\n                    continue\n                # take the initial input from ground truth data\n                if start_frame_idx % pred_len == 0:\n                    data = [get_data(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}')),\n                            get_data(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'))]\n                    data = (torch.cat(data, 2)).unsqueeze(0)\n                # get the target\n                target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')),\n                          get_data(os.path.join(vid_filepath, 
f'{start_frame_idx+3}.{suf}'))]\n                target = (torch.cat(target, 2)).unsqueeze(0)\n                # feed into the model\n                _, latent = model.high_dim_model(data.cuda(), data.cuda(), False)\n                latent = latent.squeeze(-1).squeeze(-1)\n                latent_reconstructed, latent_latent = model.model(latent)\n                output, _ = model.high_dim_model(data.cuda(), latent_reconstructed.unsqueeze(2).unsqueeze(3), True)\n\n                # compute loss\n                loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy()))\n\n                # save (2', 3'), (4', 5'), ...\n                img = tensor_to_img(output[0, :, :, :128])\n                img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}'))\n                img = tensor_to_img(output[0, :, :, 128:])\n                img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}'))\n\n                # the output becomes the input data in the next iteration\n                data = torch.tensor(output.cpu().detach().numpy()).float()\n\n            loss_dict[p_vid_idx] = loss_lst\n\n        # save the test loss for all the testing videos\n        with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file:\n            json.dump(loss_dict, file, indent=4)\n\n\ndef model_rollout_hybrid(step):\n    config_filepath = str(sys.argv[2])\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    if 'refine' not in cfg.model_name:\n        assert False, \"the hybrid scheme is only supported with refine model...\"\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisDynamicsModel(lr=cfg.lr,\n                             seed=cfg.seed,\n                             if_cuda=cfg.if_cuda,\n                     
        if_test=True,\n                             gamma=cfg.gamma,\n                             log_dir=log_dir,\n                             train_batch=cfg.train_batch,\n                             val_batch=cfg.val_batch,\n                             test_batch=cfg.test_batch,\n                             num_workers=cfg.num_workers,\n                             model_name=cfg.model_name,\n                             data_filepath=cfg.data_filepath,\n                             dataset=cfg.dataset,\n                             lr_schedule=cfg.lr_schedule)\n\n    # load model\n    checkpoint_filepath = str(sys.argv[4])\n    checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n    ckpt = torch.load(checkpoint_filepath)\n    ckpt = rename_ckpt_for_multi_models(ckpt)\n    model.model.load_state_dict(ckpt)\n\n    high_dim_checkpoint_filepath = str(sys.argv[3])\n    high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0]\n    ckpt = torch.load(high_dim_checkpoint_filepath)\n    ckpt = rename_ckpt_for_multi_models(ckpt)\n    model.high_dim_model.load_state_dict(ckpt)\n\n    model = model.to('cuda')\n    model.eval()\n    model.freeze()\n\n    # get all the test video ids\n    data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset)\n    with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:\n        seq_dict = json.load(file)\n    test_vid_ids = seq_dict['test']\n\n    pred_len = int(sys.argv[6])\n    long_term_folder = os.path.join(log_dir, 'prediction_long_term', f'hybrid_rollout_{step}')\n    loss_dict = {}\n\n    for p_vid_idx in tqdm(test_vid_ids):\n        vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))\n        total_num_frames = len(os.listdir(vid_filepath))\n        suf = os.listdir(vid_filepath)[0].split('.')[-1]\n        data = None\n        saved_folder = os.path.join(long_term_folder, str(p_vid_idx))\n     
   mkdir(saved_folder)\n        loss_lst = []\n        for start_frame_idx in range(total_num_frames - 3):\n            if start_frame_idx % 2 != 0:\n                continue\n            # take the initial input from ground truth data\n            if start_frame_idx % pred_len == 0:\n                data = [get_data(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}')),\n                        get_data(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'))]\n                data = (torch.cat(data, 2)).unsqueeze(0)\n            # get the target\n            target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')),\n                      get_data(os.path.join(vid_filepath, f'{start_frame_idx+3}.{suf}'))]\n            target = (torch.cat(target, 2)).unsqueeze(0)\n            # feed into the model\n            if (start_frame_idx + 2) % (2 * step + 2) == 0:\n                _, latent = model.high_dim_model(data.cuda(), data.cuda(), False)\n                latent = latent.squeeze(-1).squeeze(-1)\n                latent_reconstructed, latent_latent = model.model(latent)\n                output, _ = model.high_dim_model(data.cuda(), latent_reconstructed.unsqueeze(2).unsqueeze(3), True)\n            else:\n                output, _ = model.high_dim_model(data.cuda(), data.cuda(), False)\n\n            # compute loss\n            loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy()))\n\n            # save (2', 3'), (4', 5'), ...\n            img = tensor_to_img(output[0, :, :, :128])\n            img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}'))\n            img = tensor_to_img(output[0, :, :, 128:])\n            img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}'))\n\n            # the output becomes the input data in the next iteration\n            data = torch.tensor(output.cpu().detach().numpy()).float()\n\n        loss_dict[p_vid_idx] = loss_lst\n\n    # save the test loss for all 
the testing videos\n    with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file:\n        json.dump(loss_dict, file, indent=4)\n\n\ndef model_rollout_perturb(perturb_type, perturb_level):\n    config_filepath = str(sys.argv[2])\n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisDynamicsModel(lr=cfg.lr,\n                             seed=cfg.seed,\n                             if_cuda=cfg.if_cuda,\n                             if_test=True,\n                             gamma=cfg.gamma,\n                             log_dir=log_dir,\n                             train_batch=cfg.train_batch,\n                             val_batch=cfg.val_batch,\n                             test_batch=cfg.test_batch,\n                             num_workers=cfg.num_workers,\n                             model_name=cfg.model_name,\n                             data_filepath=cfg.data_filepath,\n                             dataset=cfg.dataset,\n                             lr_schedule=cfg.lr_schedule)\n    \n    # load model\n    if cfg.model_name == 'encoder-decoder' or cfg.model_name == 'encoder-decoder-64':\n        checkpoint_filepath = str(sys.argv[3])\n        checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(checkpoint_filepath)\n        model.load_state_dict(ckpt['state_dict'])\n    \n    if 'refine' in cfg.model_name:\n        checkpoint_filepath = str(sys.argv[4])\n        checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(checkpoint_filepath)\n        ckpt = rename_ckpt_for_multi_models(ckpt)\n        model.model.load_state_dict(ckpt)\n\n        high_dim_checkpoint_filepath = 
str(sys.argv[3])\n        high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0]\n        ckpt = torch.load(high_dim_checkpoint_filepath)\n        ckpt = rename_ckpt_for_multi_models(ckpt)\n        model.high_dim_model.load_state_dict(ckpt)\n\n    model = model.to('cuda')\n    model.eval()\n    model.freeze()\n\n    # get all the test video ids\n    data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset)\n    with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:\n        seq_dict = json.load(file)\n    test_vid_ids = seq_dict['test']\n\n    pred_len = int(sys.argv[6])\n    long_term_folder = os.path.join(log_dir, 'prediction_long_term', f'model_rollout_perturb_{perturb_type}_{perturb_level}')\n    loss_dict = {}\n\n    if cfg.model_name == 'encoder-decoder' or cfg.model_name == 'encoder-decoder-64':\n        for p_vid_idx in tqdm(test_vid_ids):\n            vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))\n            total_num_frames = len(os.listdir(vid_filepath))\n            suf = os.listdir(vid_filepath)[0].split('.')[-1]\n            data = None\n            saved_folder = os.path.join(long_term_folder, str(p_vid_idx))\n            mkdir(saved_folder)\n            loss_lst = []\n            for start_frame_idx in range(total_num_frames - 3):\n                if start_frame_idx % 2 != 0:\n                    continue\n                # take the initial input from ground truth data\n                if start_frame_idx % pred_len == 0:\n                    data = [get_data_perturb(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}'), perturb_type, perturb_level),\n                            get_data_perturb(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'), perturb_type, perturb_level)]\n                    data = (torch.cat(data, 2)).unsqueeze(0)\n                    img = tensor_to_img(data[0, :, :, :128])\n                    
img.save(os.path.join(saved_folder, f'{start_frame_idx}.{suf}'))\n                    img = tensor_to_img(data[0, :, :, 128:])\n                    img.save(os.path.join(saved_folder, f'{start_frame_idx+1}.{suf}'))\n                # get the target\n                target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')),\n                          get_data(os.path.join(vid_filepath, f'{start_frame_idx+3}.{suf}'))]\n                target = (torch.cat(target, 2)).unsqueeze(0)\n                # feed into the model\n                if cfg.model_name == 'encoder-decoder':\n                    output, latent = model.model(data.cuda())\n                if cfg.model_name == 'encoder-decoder-64':\n                    output, latent = model.model(data.cuda(), data.cuda(), False)\n\n                # compute loss\n                loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy()))\n                \n                # save (2', 3'), (4', 5'), ...\n                img = tensor_to_img(output[0, :, :, :128])\n                img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}'))\n                img = tensor_to_img(output[0, :, :, 128:])\n                img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}'))\n\n                # the output becomes the input data in the next iteration\n                data = torch.tensor(output.cpu().detach().numpy()).float()\n\n            loss_dict[p_vid_idx] = loss_lst\n\n        # save the test loss for all the testing videos\n        with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file:\n            json.dump(loss_dict, file, indent=4)\n\n    if 'refine' in cfg.model_name:\n        for p_vid_idx in tqdm(test_vid_ids):\n            vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))\n            total_num_frames = len(os.listdir(vid_filepath))\n            suf = os.listdir(vid_filepath)[0].split('.')[-1]\n            data = None\n   
         saved_folder = os.path.join(long_term_folder, str(p_vid_idx))\n            mkdir(saved_folder)\n            loss_lst = []\n            for start_frame_idx in range(total_num_frames - 3):\n                if start_frame_idx % 2 != 0:\n                    continue\n                # take the initial input from ground truth data\n                if start_frame_idx % pred_len == 0:\n                    data = [get_data_perturb(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}'), perturb_type, perturb_level),\n                            get_data_perturb(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'), perturb_type, perturb_level)]\n                    data = (torch.cat(data, 2)).unsqueeze(0)\n                    img = tensor_to_img(data[0, :, :, :128])\n                    img.save(os.path.join(saved_folder, f'{start_frame_idx}.{suf}'))\n                    img = tensor_to_img(data[0, :, :, 128:])\n                    img.save(os.path.join(saved_folder, f'{start_frame_idx+1}.{suf}'))\n                # get the target\n                target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')),\n                          get_data(os.path.join(vid_filepath, f'{start_frame_idx+3}.{suf}'))]\n                target = (torch.cat(target, 2)).unsqueeze(0)\n                # feed into the model\n                _, latent = model.high_dim_model(data.cuda(), data.cuda(), False)\n                latent = latent.squeeze(-1).squeeze(-1)\n                latent_reconstructed, latent_latent = model.model(latent)\n                output, _ = model.high_dim_model(data.cuda(), latent_reconstructed.unsqueeze(2).unsqueeze(3), True)\n\n                # compute loss\n                loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy()))\n\n                # save (2', 3'), (4', 5'), ...\n                img = tensor_to_img(output[0, :, :, :128])\n                img.save(os.path.join(saved_folder, 
f'{start_frame_idx+2}.{suf}'))\n                img = tensor_to_img(output[0, :, :, 128:])\n                img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}'))\n\n                # the output becomes the input data in the next iteration\n                data = torch.tensor(output.cpu().detach().numpy()).float()\n\n            loss_dict[p_vid_idx] = loss_lst\n\n        # save the test loss for all the testing videos\n        with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file:\n            json.dump(loss_dict, file, indent=4)\n\n\ndef rename_ckpt_for_multi_models(ckpt):\n    renamed_state_dict = OrderedDict()\n    for k, v in ckpt['state_dict'].items():\n        if 'high_dim_model' in k:\n            name = k.replace('high_dim_model.', '')\n        else:\n            name = k.replace('model.', '')\n        renamed_state_dict[name] = v\n    return renamed_state_dict\n\ndef get_data(filepath):\n    data = Image.open(filepath)\n    data = data.resize((128, 128))\n    data = np.array(data)\n    data = torch.tensor(data / 255.0)\n    data = data.permute(2, 0, 1).float()\n    return data\n\ndef get_data_perturb(filepath, perturb_type, perturb_level):\n    data = Image.open(filepath)\n    data = data.resize((128, 128))\n    data = np.array(data)\n    \n    bg_color = np.array([215, 205, 192])\n    rng = np.random.RandomState(int(filepath.split('/')[-2]))\n    new_bg_color = rng.randint(256, size=3)\n\n    if perturb_type == 'background_replace':\n        for i in range(2**(perturb_level-1)):\n            for j in range(2**(perturb_level-1)):\n                if np.array_equal(data[i, j], bg_color): \n                    data[i, j] = new_bg_color\n    elif perturb_type == 'background_cover':\n        for i in range(2**(perturb_level-1)):\n            for j in range(2**(perturb_level-1)):\n                data[i, j] = new_bg_color\n    elif perturb_type == 'white_noise':\n        sigma = 255.0 * (2**(perturb_level-1) / 128) ** 2\n        
data = data + rng.normal(0, sigma, data.shape)\n    else:\n        pass\n\n    data = torch.tensor(data / 255.0)\n    data = data.permute(2, 0, 1).float()\n    return data\n\n# out_tensor: 3 x 128 x 128 -> 128 x 128 x 3\ndef tensor_to_img(out_tensor):\n    # clamp to [0, 1] so out-of-range values (e.g. from white-noise perturbation) do not wrap around when cast to uint8\n    return transforms.ToPILImage()(out_tensor.clamp(0, 1)).convert(\"RGB\")\n\n\nif __name__ == '__main__':\n    if str(sys.argv[1]) == 'model-rollout':\n        model_rollout()\n    elif 'hybrid' in str(sys.argv[1]):\n        step = int(sys.argv[1].split('-')[-1])\n        model_rollout_hybrid(step)\n    elif str(sys.argv[1]) == 'latent-prediction':\n        latent_prediction()\n    elif 'perturb' in str(sys.argv[1]):\n        perturb_type = str(sys.argv[1].split('-')[-2])\n        perturb_level = int(sys.argv[1].split('-')[-1])\n        model_rollout_perturb(perturb_type, perturb_level)\n    else:\n        raise ValueError(\"prediction scheme is not supported...\")"
  },
  {
    "path": "requirements.txt",
    "content": "absl-py==0.7.1\naiohttp==3.7.4.post0\naiohttp-cors==0.7.0\naioredis==1.3.1\nastor==0.7.1\nastunparse==1.6.3\nasync-timeout==3.0.1\natari-py==0.1.15\natomicwrites==1.3.0\nattrs==19.3.0\naudioread==2.1.6\nbackcall==0.1.0\nbackports.shutil-get-terminal-size==1.0.0\nbleach==1.5.0\nblessings==1.7\ncachetools==4.1.1\ncattrs==1.0.0\ncertifi==2019.3.9\nchardet==3.0.4\nchumpy==0.70\nclick==8.0.1\ncloudpickle==0.8.1\ncolorama==0.4.4\ncontextvars==2.4\ncycler==0.10.0\nCython==0.29.22\ndataclasses==0.8\ndecorator==4.4.0\ndeepdish==0.3.6\ndefusedxml==0.6.0\ndill==0.2.9\nDjango==2.2.1\ndocopt==0.6.2\nentrypoints==0.3\nfilelock==3.0.12\nFlask==2.0.1\nflatbuffers==1.12\nfsspec==0.8.4\nfuture==0.17.1\ngast==0.3.3\nglob2==0.6\ngoogle-api-core==1.28.0\ngoogle-auth==1.30.1\ngoogle-auth-oauthlib==0.4.1\ngoogle-pasta==0.2.0\ngoogleapis-common-protos==1.53.0\ngpustat==0.6.0\ngrpcio==1.34.0\ngym==0.12.5\nh5py==2.10.0\nhiredis==2.0.0\nhtml5lib==0.9999999\nidna==2.8\nidna-ssl==1.1.0\nimage==1.5.25\nimageio==2.5.0\nimmutables==0.15\nimportlib-metadata==4.3.0\nimutils==0.5.2\nipdb==0.12\nipykernel==5.1.0\nipython==7.4.0\nipython-genutils==0.2.0\nipywidgets==7.4.2\nitsdangerous==2.0.1\njedi==0.13.3\nJinja2==3.0.1\njoblib==0.13.2\njsonschema==3.0.1\njupyter==1.0.0\njupyter-client==5.2.4\njupyter-console==6.0.0\njupyter-core==4.4.0\nkaleido==0.2.1\nKeras-Preprocessing==1.1.2\nkiwisolver==1.0.1\nlibrosa==0.6.3\nllvmlite==0.28.0\nMarkdown==3.1\nMarkupSafe==2.0.1\nmatplotlib==2.2.2\nmistune==0.8.4\nmock==3.0.5\nmore-itertools==7.0.0\nmpi4py==3.0.3\nmsgpack==1.0.2\nmultidict==5.1.0\nmunch==2.5.0\nmunkres==1.0.12\nnbconvert==5.4.1\nnbformat==4.4.0\nnetworkx==2.3\nnibabel==2.4.0\nnotebook==5.7.8\nnumba==0.43.1\nnumexpr==2.6.9\nnumpy==1.19.4\nnvidia-ml-py3==7.352.0\noauthlib==3.1.0\nopencensus==0.7.13\nopencensus-context==0.1.2\nopencv-python==4.5.2.52\nopt-einsum==3.3.0\npackaging==20.9\npandas==0.24.2\npandocfilters==1.4.2\nparso==0.3.4\npexpect==4.6.0\npickleshare==0.7.5\nPillow==7.1.
2\nplotly==4.14.3\npluggy==0.11.0\nprometheus-client==0.10.1\nprompt-toolkit==2.0.9\nprotobuf==3.17.1\npsutil==5.8.0\nptyprocess==0.6.0\npy==1.8.0\npy-spy==0.3.7\npyasn1==0.4.8\npyasn1-modules==0.2.8\npyglet==1.3.2\nPygments==2.3.1\nPyOpenGL==3.1.0\npyparsing==2.3.1\npyrsistent==0.14.11\npytest==3.10.1\npython-dateutil==2.8.0\npytorch-lightning==1.0.8\npytz==2019.1\nPyWavelets==1.0.3\nPyYAML==5.1\npyzmq==18.0.1\nqtconsole==4.4.3\nray==1.3.0\nredis==3.5.3\nreprint==0.5.2\nrequests==2.21.0\nrequests-oauthlib==1.3.0\nresampy==0.2.1\nretrying==1.3.3\nrsa==4.6\nscikit-image==0.15.0\nscikit-learn==0.20.3\nscipy==1.4.1\nseaborn==0.9.0\nSend2Trash==1.5.0\nShapely==1.6.4.post2\nsix==1.16.0\nsklearn==0.0\nsqlparse==0.3.0\nstable-baselines==2.5.1\ntables==3.5.1\ntabulate==0.8.9\ntensorboard==2.2.2\ntensorboard-plugin-wit==1.7.0\ntensorboardX==2.2\ntermcolor==1.1.0\nterminado==0.8.2\ntestpath==0.4.2\n--find-links https://download.pytorch.org/whl/torch_stable.html\ntorch==1.7.0\ntorchfile==0.1.0\ntorchvision==0.8.1\ntornado==6.0.2\ntqdm==4.54.1\ntraitlets==4.3.2\ntyping-extensions==3.7.4.3\nurllib3==1.24.1\nvisdom==0.1.8.8\nwcwidth==0.1.7\nwebencodings==0.5.1\nwebsocket-client==0.56.0\nWerkzeug==2.0.1\nwidgetsnbextension==3.4.2\nwrapt==1.12.1\nyarl==1.6.3\nzipp==3.4.1\nzmq==0.0.0"
  },
  {
    "path": "scripts/encoder_decoder_64_eval.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"=============================================================================================\"\necho \"============== Evaluating encoder-decoder-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"=============================================================================================\"\n\nscreen -S eval-\"$dataset\"-64 -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA eval-eval NA; \\\n                                          CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA eval-eval NA; \\\n                                          CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA eval-eval NA; \\\n                                          exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_64_eval_gather.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"=========================================================================================================\"\necho \"============== Evaluating (Gathering) encoder-decoder-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"=========================================================================================================\"\n\nscreen -S eval-\"$dataset\"-64 -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA eval-train NA; \\\n                                          CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA eval-train NA; \\\n                                          CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA eval-train NA; \\\n                                          exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_64_long_term_model_rollout.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"==========================================================================================================\"\necho \"============== Long-term model rollout encoder-decoder-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"==========================================================================================================\"\n\nscreen -S eval-\"$dataset\"-longterm-modelrollout -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_64_long_term_model_rollout_perturb_all.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"==========================================================================================================\"\necho \"============== Long-term model rollout encoder-decoder-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"==========================================================================================================\"\n\nscreen -S eval-\"$dataset\"-longterm-modelrollout -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py 
model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                      
                       \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 
../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/model64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py 
model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/model64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/model64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_64_train.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"===========================================================================================\"\necho \"============== Training encoder-decoder-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"===========================================================================================\"\n\nscreen -S train-\"$dataset\"-64 -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/model64/config1.yaml; \\\n                                           CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/model64/config2.yaml; \\\n                                           CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/model64/config3.yaml; \\\n                                           exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_estimate_dimension.sh",
    "content": "#!/bin/bash\n\ndataset=$1\n\necho \"============================================================================================\"\necho \"========== Estimating intrinsic dimension from encoder-decoder model on: $dataset ==========\"\necho \"============================================================================================\"\n\nscreen -S eval-\"$dataset\"-dimension -dm bash -c \"OMP_NUM_THREADS=4 python ../analysis/eval_intrinsic_dimension.py ../configs/\"$dataset\"/model/config1.yaml model-latent NA; \\\n                                                 OMP_NUM_THREADS=4 python ../analysis/eval_intrinsic_dimension.py ../configs/\"$dataset\"/model/config2.yaml model-latent NA; \\\n                                                 OMP_NUM_THREADS=4 python ../analysis/eval_intrinsic_dimension.py ../configs/\"$dataset\"/model/config3.yaml model-latent NA; \\\n                                                 exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_eval.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"==========================================================================================\"\necho \"============== Evaluating encoder-decoder model on: $dataset (gpu id: $gpu) ==============\"\necho \"==========================================================================================\"\n\nscreen -S eval-\"$dataset\" -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA eval-eval NA; \\\n                                       CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA eval-eval NA; \\\n                                       CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA eval-eval NA; \\\n                                       exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_eval_gather.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"======================================================================================================\"\necho \"============== Evaluating (Gathering) encoder-decoder model on: $dataset (gpu id: $gpu) ==============\"\necho \"======================================================================================================\"\n\nscreen -S eval-\"$dataset\" -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA eval-train NA; \\\n                                       CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA eval-train NA; \\\n                                       CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA eval-train NA; \\\n                                       exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_long_term_model_rollout.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"=======================================================================================================\"\necho \"============== Long-term model rollout encoder-decoder model on: $dataset (gpu id: $gpu) ==============\"\necho \"=======================================================================================================\"\n\nscreen -S eval-\"$dataset\"-longterm-modelrollout -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_long_term_model_rollout_perturb_all.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"=======================================================================================================\"\necho \"============== Long-term model rollout encoder-decoder model on: $dataset (gpu id: $gpu) ==============\"\necho \"=======================================================================================================\"\n\nscreen -S eval-\"$dataset\"-longterm-modelrollout -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 
../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             
CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                               
                              CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/model/config1.yaml ./logs_\"$dataset\"_encoder-decoder_1/lightning_logs/checkpoints NA NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/model/config2.yaml ./logs_\"$dataset\"_encoder-decoder_2/lightning_logs/checkpoints NA 
NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/model/config3.yaml ./logs_\"$dataset\"_encoder-decoder_3/lightning_logs/checkpoints NA NA 60; \\\n                                                             \n                                                             exec sh\";"
  },
  {
    "path": "scripts/encoder_decoder_train.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"========================================================================================\"\necho \"============== Training encoder-decoder model on: $dataset (gpu id: $gpu) ==============\"\necho \"========================================================================================\"\n\nscreen -S train-\"$dataset\" -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/model/config1.yaml; \\\n                                        CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/model/config2.yaml; \\\n                                        CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/model/config3.yaml; \\\n                                        exec sh\";"
  },
  {
    "path": "scripts/latentpred_train.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"========================================================================================\"\necho \"================ Training latentpred model on: $dataset (gpu id: $gpu) =================\"\necho \"========================================================================================\"\n\nscreen -S train-\"$dataset\" -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py latentpred ../configs/\"$dataset\"/latentpred/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints/; \\\n                                        CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py latentpred ../configs/\"$dataset\"/latentpred/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints/; \\\n                                        CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py latentpred ../configs/\"$dataset\"/latentpred/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints/; \\\n                                        exec sh\";"
  },
  {
    "path": "scripts/long_term_eval_stability.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"============================================================================================\"\necho \"========= Evaluating stability of long term prediction on: $dataset (gpu id: $gpu) =========\"\necho \"============================================================================================\"\n\nscreen -S eval-\"$dataset\"-long-term-stab -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config1.yaml ./logs_\"$dataset\"_latent-prediction_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder_1/prediction_long_term/model_rollout/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config1.yaml ./logs_\"$dataset\"_latent-prediction_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_1/prediction_long_term/model_rollout/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config1.yaml ./logs_\"$dataset\"_latent-prediction_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_1/prediction_long_term/model_rollout/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config2.yaml ./logs_\"$dataset\"_latent-prediction_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints/ 
./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder_2/prediction_long_term/model_rollout/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config2.yaml ./logs_\"$dataset\"_latent-prediction_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_2/prediction_long_term/model_rollout/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config2.yaml ./logs_\"$dataset\"_latent-prediction_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_2/prediction_long_term/model_rollout/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config3.yaml ./logs_\"$dataset\"_latent-prediction_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder_3/prediction_long_term/model_rollout/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config3.yaml ./logs_\"$dataset\"_latent-prediction_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_3/prediction_long_term/model_rollout/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py 
../configs/\"$dataset\"/latentpred/config3.yaml ./logs_\"$dataset\"_latent-prediction_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_3/prediction_long_term/model_rollout/ 60; \\\n                                                      exec sh\";"
  },
  {
    "path": "scripts/long_term_eval_stability_hybrid.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\nstep=$3\n\necho \"============================================================================================\"\necho \"========= Evaluating stability of long term prediction on: $dataset (gpu id: $gpu) =========\"\necho \"============================================================================================\"\n\nscreen -S eval-\"$dataset\"-long-term-stab -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config1.yaml ./logs_\"$dataset\"_latent-prediction_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_1/prediction_long_term/hybrid_rollout_$step/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config2.yaml ./logs_\"$dataset\"_latent-prediction_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_2/prediction_long_term/hybrid_rollout_$step/ 60; \\\n                                                      CUDA_VISIBLE_DEVICES=\"$gpu\" python ../stability.py ../configs/\"$dataset\"/latentpred/config3.yaml ./logs_\"$dataset\"_latent-prediction_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints/ ./logs_\"$dataset\"_refine-64_3/prediction_long_term/hybrid_rollout_$step/ 60; \\\n                                                      exec sh\";"
  },
  {
    "path": "scripts/refine_64_eval.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"====================================================================================\"\necho \"============== Evaluating refine-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"====================================================================================\"\n\nscreen -S eval-\"$dataset\"-refine-64 -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints eval-eval NA; \\\n                                                 CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints eval-eval NA; \\\n                                                 CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints eval-eval NA; \\\n                                                 exec sh\";"
  },
  {
    "path": "scripts/refine_64_eval_gather.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"================================================================================================\"\necho \"============== Evaluating (Gathering) refine-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"================================================================================================\"\n\nscreen -S eval-\"$dataset\"-refine-64 -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints eval-refine-train NA; \\\n                                                 CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints eval-refine-train NA; \\\n                                                 CUDA_VISIBLE_DEVICES=\"$gpu\" python ../eval.py ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints eval-refine-train NA; \\\n                                                 exec sh\";"
  },
  {
    "path": "scripts/refine_64_long_term_hybrid_rollout.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\nstep=$3\n\necho \"================================================================================================\"\necho \"======= Long-term hybrid-$step model rollout refine-64 model on: $dataset (gpu id: $gpu) =======\"\necho \"================================================================================================\"\n\nscreen -S eval-\"$dataset\"-longterm-hybridrollout -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py hybrid-$step ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                              CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py hybrid-$step ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                              CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py hybrid-$step ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                              exec sh\";"
  },
  {
    "path": "scripts/refine_64_long_term_model_rollout.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"=================================================================================================\"\necho \"============== Long-term model rollout refine-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"=================================================================================================\"\n\nscreen -S eval-\"$dataset\"-longterm-modelrollout -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             exec sh\";"
  },
  {
    "path": "scripts/refine_64_long_term_model_rollout_perturb_all.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"=================================================================================================\"\necho \"============== Long-term model rollout refine-64 model on: $dataset (gpu id: $gpu) ==============\"\necho \"=================================================================================================\"\n\nscreen -S eval-\"$dataset\"-longterm-modelrollout -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/refine64/config2.yaml 
./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             
CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/refine64/config1.yaml 
./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             \n                                                             
CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                                                             \n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/refine64/config1.yaml ./logs_\"$dataset\"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_1/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/refine64/config2.yaml ./logs_\"$dataset\"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_2/lightning_logs/checkpoints NA 60; \\\n                                                             CUDA_VISIBLE_DEVICES=\"$gpu\" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/\"$dataset\"/refine64/config3.yaml ./logs_\"$dataset\"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_\"$dataset\"_refine-64_3/lightning_logs/checkpoints NA 60; \\\n                        
                                     \n                                                             exec sh\";"
  },
  {
    "path": "scripts/refine_64_train.sh",
    "content": "#!/bin/bash\n\ndataset=$1\ngpu=$2\n\necho \"===================================================================================\"\necho \"============== Training refine-64 model on: $dataset (gpu id: $gpu) ===============\"\necho \"===================================================================================\"\n\nscreen -S train-\"$dataset\"-refine-64 -dm bash -c \"CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/refine64/config1.yaml; \\\n                                                  CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/refine64/config2.yaml; \\\n                                                  CUDA_VISIBLE_DEVICES=\"$gpu\" python ../main.py ../configs/\"$dataset\"/refine64/config3.yaml; \\\n                                                  exec sh\";"
  },
  {
    "path": "stability.py",
    "content": "import os\nimport sys\nimport glob\nimport yaml\nimport json\nimport torch\nimport pprint\nimport shutil\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom munch import munchify\nfrom models_latentpred import VisLatentDynamicsModel\nfrom dataset import NeuralPhysDataset\nfrom pytorch_lightning import Trainer, seed_everything\nfrom pytorch_lightning.loggers import TensorBoardLogger\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\ndef load_config(filepath):\n    with open(filepath, 'r') as stream:\n        try:\n            trainer_params = yaml.safe_load(stream)\n            return trainer_params\n        except yaml.YAMLError as exc:\n            print(exc)\n\ndef seed(cfg):\n    torch.manual_seed(cfg.seed)\n    if cfg.if_cuda:\n        torch.cuda.manual_seed(cfg.seed)\n\n\ndef get_data(filepath):\n    data = Image.open(filepath)\n    data = data.resize((128, 128))\n    data = np.array(data)\n    data = torch.tensor(data / 255.0)\n    data = data.permute(2, 0, 1).float()\n    return data\n\n\ndef main():\n    config_filepath = str(sys.argv[1])\n    checkpoint_filepath = str(sys.argv[2])\n    high_dim_checkpoint_filepath = str(sys.argv[3])\n    refine_checkpoint_filepath = str(sys.argv[4])\n    pred_save_path = str(sys.argv[5])\n    pred_len = int(sys.argv[6])\n    \n    cfg = load_config(filepath=config_filepath)\n    pprint.pprint(cfg)\n    cfg = munchify(cfg)\n    seed(cfg)\n    seed_everything(cfg.seed)\n\n    log_dir = '_'.join([cfg.log_dir,\n                        cfg.dataset,\n                        cfg.model_name,\n                        str(cfg.seed)])\n\n    model = VisLatentDynamicsModel(lr=cfg.lr,\n                                   seed=cfg.seed,\n                                   if_cuda=cfg.if_cuda,\n                                   if_test=True,\n                                   gamma=cfg.gamma,\n                                   
log_dir=log_dir,\n                                   train_batch=cfg.train_batch,\n                                   val_batch=cfg.val_batch,\n                                   test_batch=cfg.test_batch,\n                                   num_workers=cfg.num_workers,\n                                   model_name=cfg.model_name,\n                                   data_filepath=cfg.data_filepath,\n                                   dataset=cfg.dataset,\n                                   lr_schedule=cfg.lr_schedule)\n\n    model.load_model(checkpoint_filepath)\n    model.load_high_dim_refine_model(high_dim_checkpoint_filepath, refine_checkpoint_filepath)\n    model.extract_decoder_from_refine_model()\n    model = model.to('cuda')\n    model.eval()\n    model.freeze()\n\n    # get all the test video ids\n    data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset)\n    with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:\n        seq_dict = json.load(file)\n    test_vid_ids = seq_dict['test']\n    \n    errors = []\n    for p_vid_idx in tqdm(test_vid_ids):\n        data_vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))\n        pred_vid_filepath = os.path.join(pred_save_path, str(p_vid_idx))\n        total_num_frames = len(os.listdir(data_vid_filepath))\n        suf = os.listdir(data_vid_filepath)[0].split('.')[-1]\n        \n        error_p = []\n        for start_frame_idx in range(total_num_frames - 3):\n            if start_frame_idx % 2 != 0:\n                continue\n            # get the data\n            if start_frame_idx % pred_len == 0:\n                data = [get_data(os.path.join(data_vid_filepath, f'{start_frame_idx}.{suf}')), \n                        get_data(os.path.join(data_vid_filepath, f'{start_frame_idx+1}.{suf}'))]\n                data = (torch.cat(data, 2)).unsqueeze(0).cuda()\n            else:\n                data = [get_data(os.path.join(pred_vid_filepath, 
f'{start_frame_idx}.{suf}')), \n                        get_data(os.path.join(pred_vid_filepath, f'{start_frame_idx+1}.{suf}'))]\n                data = (torch.cat(data, 2)).unsqueeze(0).cuda()\n            # get the target\n            target = [get_data(os.path.join(pred_vid_filepath, f'{start_frame_idx+2}.{suf}')), \n                      get_data(os.path.join(pred_vid_filepath, f'{start_frame_idx+3}.{suf}'))]\n            target = (torch.cat(target, 2)).unsqueeze(0).cuda()\n            # compute latent space error\n            data_state = model.data_to_state(data)\n            target_state = model.data_to_state(target)\n            output_state = model.model(data_state)\n            error_p.append(float(model.loss_func(output_state, target_state).cpu().detach().numpy()))\n        errors.append(error_p)\n    \n    errors = np.array(errors)\n    np.save(os.path.join(pred_save_path, 'stability.npy'), errors)\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "utils/common.py",
    "content": "import os\nimport shutil\nimport json\nimport random\nimport numpy as np\nfrom PIL import Image\nimport plotly.graph_objects as go\n\n\ndef mkdir(folder):\n    if os.path.exists(folder):\n        shutil.rmtree(folder)\n    os.makedirs(folder)\n\n\ndef create_random_data_splits(seed, data_filepath, object_name, ratio=0.8):\n    random.seed(seed)\n    seq_dict = {}\n\n    obj_filepath = os.path.join(data_filepath, object_name)\n    num_vids = int(len(os.listdir(obj_filepath)) - 1)\n    vid_id_lst = list(range(num_vids))\n    random.shuffle(vid_id_lst)\n\n    # test\n    start = int(num_vids * (ratio + (1 - ratio) / 2))\n    seq_dict['test'] = vid_id_lst[start:]\n    # val\n    start = int(num_vids * ratio)\n    end = int(num_vids * (ratio + (1 - ratio) / 2))\n    seq_dict['val'] = vid_id_lst[start:end]\n    # train\n    seq_dict['train'] = vid_id_lst[:int(num_vids * ratio)]\n\n    with open(os.path.join(f'./datainfo/{object_name}/data_split_dict_{seed}.json'), 'w') as file:\n        json.dump(seq_dict, file, indent=4)\n\n\ndef separate_eval_image(filepath):\n    img = Image.open(filepath)\n    img.crop((2, 2, 130, 130)).save('input_0.png')\n    img.crop((2, 132, 130, 260)).save('input_1.png')\n    img.crop((2, 262, 130, 390)).save('truth_0.png')\n    img.crop((2, 392, 130, 520)).save('truth_1.png')\n    img.crop((2, 522, 130, 650)).save('output_0.png')\n    img.crop((2, 652, 130, 780)).save('output_1.png')\n\n\ndef collect_long_term_pred_sample(dataset, seed, test_vid_id, hybrid_step, start=0, step=12, stop=60, suf='png'):\n    idv = id_value[dataset]\n    fps = fps_value[dataset]\n    save_path = f'long_term_pred_sample_{dataset}_seed_{seed}_vid_{test_vid_id}' \n    mkdir(save_path)\n    \n    data_filepath = f'/data/kuang/visphysics_data/{dataset}/{test_vid_id}'\n    pred_filepaths = {'dim-8192':f'/data/kuang/logs/logs_{dataset}_encoder-decoder_{seed}/prediction_long_term/model_rollout/{test_vid_id}',\n                      
'dim-64':f'/data/kuang/logs/logs_{dataset}_encoder-decoder-64_{seed}/prediction_long_term/model_rollout/{test_vid_id}',\n                      f'dim-{idv}':f'/data/kuang/logs/logs_{dataset}_refine-64_{seed}/prediction_long_term/model_rollout/{test_vid_id}',\n                      f'hybrid-64-{idv}':f'/data/kuang/logs/logs_{dataset}_refine-64_{seed}/prediction_long_term/hybrid_rollout_{hybrid_step}/{test_vid_id}'}\n    \n    Image.open(os.path.join(data_filepath, str(start)+'.'+suf)).save(os.path.join(save_path, 'input_0.png'))\n    Image.open(os.path.join(data_filepath, str(start+1)+'.'+suf)).save(os.path.join(save_path, 'input_1.png'))\n    \n    for k in range(start+step, stop+1, step):\n        data = Image.open(os.path.join(data_filepath, str(k-1)+'.'+suf))\n        data.save(os.path.join(save_path, 'truth_%.1fs.png'%((k-start)/fps)))\n        for scheme in ['dim-8192', 'dim-64', f'dim-{idv}', f'hybrid-64-{idv}']:\n            pred = Image.open(os.path.join(pred_filepaths[scheme], str(k-1)+'.'+suf))\n            pred.save(os.path.join(save_path, scheme+'_%.1fs.png'%((k-start)/fps)))\n\n\n\"\"\"\n    physical system parameters\n\"\"\"\nid_value = {'circular_motion':2, 'reaction_diffusion':2, 'single_pendulum':2,\n            'double_pendulum':4, 'swingstick_non_magnetic':4, 'elastic_pendulum':6,\n            'air_dancer':8, 'lava_lamp':8, 'fire':24}\nfps_value = {'circular_motion':60, 'reaction_diffusion':5, 'single_pendulum':60,\n             'double_pendulum':60, 'swingstick_non_magnetic':60, 'elastic_pendulum':60,\n             'air_dancer':60, 'lava_lamp':2.5, 'fire':24}\n\n\n\"\"\"\n    plot style\n\"\"\"\ncols = ['#df1e1e', '#f0975a', '#1b66cb', '#149650']\n\ncolorscale=[[0.0, \"rgb(49,54,149)\"],\n            [0.1111111111111111, \"rgb(69,117,180)\"],\n            [0.2222222222222222, \"rgb(116,173,209)\"],\n            [0.3333333333333333, \"rgb(171,217,233)\"],\n            [0.4444444444444444, \"rgb(224,243,248)\"],\n            [0.5555555555555556, 
\"rgb(254,224,144)\"],\n            [0.6666666666666666, \"rgb(253,174,97)\"],\n            [0.7777777777777778, \"rgb(244,109,67)\"],\n            [0.8888888888888888, \"rgb(215,48,39)\"],\n            [1.0, \"rgb(165,0,38)\"]]\n\ndef light_color(col):\n    # return a 20%-opacity rgba() version of an 'rgb(r,g,b)' or '#rrggbb' color string\n    if col[:3] == 'rgb':\n        rgb = col[4:-1]\n    elif col[0] == '#':\n        rgb = str(tuple(int(col.lstrip('#')[i:i+2],16) for i in (0,2,4)))[1:-1]\n    else:\n        raise ValueError('unsupported color format: ' + col)\n    return 'rgba('+rgb +',0.2)'\n\ndef update_figure(fig):\n    fig.update_xaxes(showgrid=False, tickfont=dict(family=\"sans serif\", size=18, color=\"black\"))\n    fig.update_yaxes(showgrid=False, tickfont=dict(family=\"sans serif\", size=18, color=\"black\"))\n    fig.update_xaxes(showline=True, linewidth=2, linecolor=\"black\")\n    fig.update_yaxes(showline=True, linewidth=2, linecolor=\"black\")\n    fig.update_layout(legend=go.layout.Legend(traceorder=\"normal\",\n                                              font=dict(family=\"sans serif\", size=18, color=\"black\"),\n                    ))\n    fig.update_layout(font=dict(family=\"sans serif\", size=24, color=\"black\"),\n                      paper_bgcolor='rgba(0,0,0,0)',\n                      plot_bgcolor='rgba(0,0,0,0)',\n                     )\n\ndef test():\n    fig = go.Figure()\n    x = np.arange(30)\n    for k in range(4):\n        fig.add_trace(go.Scatter(x=x, y=x*(k+1), mode='lines', line=dict(width=4, color=cols[k])))\n    update_figure(fig)\n    fig.write_image('test.png', scale=4)\n\n\n\"\"\"\ntest code only\n\"\"\"\nif __name__ == '__main__':\n    test()"
  },
  {
    "path": "utils/dimension.py",
    "content": "import numpy as np\nfrom scipy import stats\nimport os\nimport sys\nimport yaml\nfrom munch import munchify\n\ndef load_config(filepath):\n    with open(filepath, 'r') as stream:\n        try:\n            trainer_params = yaml.safe_load(stream)\n            return trainer_params\n        except yaml.YAMLError as exc:\n            print(exc)\n\n\ndef calculate_intrinsic_dimension_statistics(dataset, model_type='model', if_image=False):\n    configs_dir = os.path.join('../configs', dataset, model_type)\n    if if_image:\n        filename = 'intrinsic_dimension_image.npy'\n    else:\n        filename = 'intrinsic_dimension.npy'\n    dims_all = []\n    for config_filepath in os.listdir(configs_dir):\n        cfg = load_config(filepath=os.path.join(configs_dir, config_filepath))\n        cfg = munchify(cfg)\n        log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])\n        vars_filepath = os.path.join(log_dir, 'variables')\n        dims = np.load(os.path.join(vars_filepath, filename))\n        dims_all.append(dims)\n    dims_all = np.concatenate(dims_all)\n    dim_mean = np.mean(dims_all)\n    dim_std = np.std(dims_all)\n    print('Mean (±std):', '%.2f (±%.2f)' % (dim_mean, dim_std))\n    print('Confidence interval:', '(%.1f, %.1f)' % (dim_mean-1.96*dim_std, dim_mean+1.96*dim_std))\n\n\ndef calculate_intrinsic_dimension_statistics_all_methods(dataset, model_type='model', if_image=False):\n    configs_dir = os.path.join('../configs', dataset, model_type)\n    if if_image:\n        filename = 'intrinsic_dimension_image_all_methods.npy'\n    else:\n        filename = 'intrinsic_dimension_all_methods.npy'\n    all_methods = ['Levina_Bickel', 'MiND_ML', 'MiND_KL', 'Hein', 'CD']\n    dims_all = {method:[] for method in all_methods}\n    for config_filepath in os.listdir(configs_dir):\n        cfg = load_config(filepath=os.path.join(configs_dir, config_filepath))\n        cfg = munchify(cfg)\n        log_dir = 
'_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])\n        vars_filepath = os.path.join(log_dir, 'variables')\n        dims = np.load(os.path.join(vars_filepath, filename), allow_pickle=True).item()\n        for method in all_methods:\n            dims_all[method].append(dims[method])\n    for method in all_methods:\n        if method not in ['Hein', 'CD']:\n            dims_all[method] = np.concatenate(dims_all[method])\n        dim_mean = np.mean(dims_all[method])\n        dim_std = np.std(dims_all[method])\n        print(method + ':', '%.2f (±%.2f)' % (dim_mean, dim_std))\n\n\nif __name__ == '__main__':\n    dataset = str(sys.argv[1])\n    calculate_intrinsic_dimension_statistics(dataset, 'model')\n    # calculate_intrinsic_dimension_statistics(dataset, 'model', if_image=True)\n    # calculate_intrinsic_dimension_statistics_all_methods(dataset, 'model')"
  }
]