Repository: BoyuanChen/neural-state-variables
Branch: main
Commit: 10483d93ac8c
Files: 227
Total size: 570.0 KB
Directory structure:
gitextract_ivg233ui/
├── .gitignore
├── LICENSE
├── README.md
├── analysis/
│ ├── README.md
│ ├── eval_intrinsic_dimension.py
│ ├── eval_phys_data.py
│ ├── eval_phys_double_pendulum/
│ │ ├── __init__.py
│ │ ├── angle_estimator.py
│ │ └── physics_estimator.py
│ ├── eval_phys_elastic_pendulum/
│ │ ├── __init__.py
│ │ ├── angle_estimator.py
│ │ └── physics_estimator.py
│ ├── eval_phys_long_term_pred.py
│ ├── eval_phys_single_pendulum/
│ │ ├── __init__.py
│ │ ├── angle_estimator.py
│ │ └── physics_estimator.py
│ ├── eval_regression.py
│ ├── intrinsic_dimension_estimation/
│ │ ├── __init__.py
│ │ ├── matlab_codes/
│ │ │ ├── DANCo.m
│ │ │ ├── DANCoFit.m
│ │ │ ├── DANCoTrain.m
│ │ │ ├── DANCo_fits.mat
│ │ │ ├── GetDim.mexw64
│ │ │ ├── KNN.m
│ │ │ ├── MLE.m
│ │ │ ├── MiND_KL.m
│ │ │ ├── MiND_ML.m
│ │ │ ├── demo_idEstimation.m
│ │ │ ├── gauss.m
│ │ │ ├── html/
│ │ │ │ └── demo_idEstimation.html
│ │ │ ├── linSubspSpanOrthonormalize.m
│ │ │ ├── parseParamsNamed.m
│ │ │ ├── private/
│ │ │ │ ├── DANCo_estimateKL.m
│ │ │ │ └── DANCo_statistics.m
│ │ │ └── randsphere.m
│ │ └── methods.py
│ └── latent_regression/
│ ├── __init__.py
│ └── regressors.py
├── configs/
│ ├── air_dancer/
│ │ ├── latentpred/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── circular_motion/
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── double_pendulum/
│ │ ├── latentpred/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── elastic_pendulum/
│ │ ├── latentpred/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── fire/
│ │ ├── latentpred/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── lava_lamp/
│ │ ├── latentpred/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── reaction_diffusion/
│ │ ├── latentpred/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── single_pendulum/
│ │ ├── latentpred/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── swingstick_magnetic/
│ │ ├── model/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ ├── model64/
│ │ │ ├── config1.yaml
│ │ │ ├── config2.yaml
│ │ │ └── config3.yaml
│ │ └── refine64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ └── swingstick_non_magnetic/
│ ├── latentpred/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── model/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ ├── model64/
│ │ ├── config1.yaml
│ │ ├── config2.yaml
│ │ └── config3.yaml
│ └── refine64/
│ ├── config1.yaml
│ ├── config2.yaml
│ └── config3.yaml
├── datainfo/
│ ├── README.md
│ ├── air_dancer/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ ├── circular_motion/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ ├── collect/
│ │ ├── circular_motion/
│ │ │ └── make_data.py
│ │ ├── double_pendulum/
│ │ │ ├── README.md
│ │ │ ├── convert_video.py
│ │ │ ├── equalize_background.py
│ │ │ └── split_data.py
│ │ ├── elastic_pendulum/
│ │ │ └── make_data.py
│ │ ├── fire/
│ │ │ └── split_data.py
│ │ ├── lava_lamp/
│ │ │ └── split_data.py
│ │ ├── reaction_diffusion/
│ │ │ ├── make_data.py
│ │ │ ├── reaction_diffusion.m
│ │ │ └── reaction_diffusion_rhs.m
│ │ ├── single_pendulum/
│ │ │ └── make_data.py
│ │ └── utils/
│ │ └── sort_vids.py
│ ├── double_pendulum/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ ├── elastic_pendulum/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ ├── fire/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ ├── lava_lamp/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ ├── reaction_diffusion/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ ├── single_pendulum/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ ├── swingstick_magnetic/
│ │ ├── data_split_dict_1.json
│ │ ├── data_split_dict_2.json
│ │ └── data_split_dict_3.json
│ └── swingstick_non_magnetic/
│ ├── data_split_dict_1.json
│ ├── data_split_dict_2.json
│ └── data_split_dict_3.json
├── dataset.py
├── eval.py
├── main.py
├── model_utils.py
├── models.py
├── models_latentpred.py
├── pred.py
├── requirements.txt
├── scripts/
│ ├── encoder_decoder_64_eval.sh
│ ├── encoder_decoder_64_eval_gather.sh
│ ├── encoder_decoder_64_long_term_model_rollout.sh
│ ├── encoder_decoder_64_long_term_model_rollout_perturb_all.sh
│ ├── encoder_decoder_64_train.sh
│ ├── encoder_decoder_estimate_dimension.sh
│ ├── encoder_decoder_eval.sh
│ ├── encoder_decoder_eval_gather.sh
│ ├── encoder_decoder_long_term_model_rollout.sh
│ ├── encoder_decoder_long_term_model_rollout_perturb_all.sh
│ ├── encoder_decoder_train.sh
│ ├── latentpred_train.sh
│ ├── long_term_eval_stability.sh
│ ├── long_term_eval_stability_hybrid.sh
│ ├── refine_64_eval.sh
│ ├── refine_64_eval_gather.sh
│ ├── refine_64_long_term_hybrid_rollout.sh
│ ├── refine_64_long_term_model_rollout.sh
│ ├── refine_64_long_term_model_rollout_perturb_all.sh
│ └── refine_64_train.sh
├── stability.py
└── utils/
├── common.py
└── dimension.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.*~
*~
__pycache__
scripts/logs_*
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 Boyuan Chen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Discovering State Variables Hidden in Experimental Data
[Boyuan Chen](http://boyuanchen.com/),
[Kuang Huang](https://cm3.apam.columbia.edu/people/kuang-huang),
[Sunand Raghupathi](https://www.linkedin.com/in/sunand-raghupathi?trk=public_profile_browsemap),
[Ishaan Chandratreya](https://www.linkedin.com/in/ishaan-chandratreya-546aa6141),
[Qiang Du](https://www.engineering.columbia.edu/faculty/qiang-du),
[Hod Lipson](https://www.hodlipson.com/)
Columbia University
### [Project Website](https://www.cs.columbia.edu/~bchen/neural-state-variables) | [Video](https://youtu.be/KWHXchlJzSw) | [Paper](http://arxiv.org/abs/2112.10755)
## Overview
This repo contains the PyTorch implementation for paper "Discovering State Variables Hidden in Experimental Data".

## Citation
If you find our paper or codebase helpful, please consider citing:
```
@article{chen2021discover,
title={Discovering State Variables Hidden in Experimental Data},
author={Chen, Boyuan and Huang, Kuang and Raghupathi, Sunand and Chandratreya, Ishaan and Du, Qiang and Lipson, Hod},
journal={arXiv preprint arXiv:2112.10755},
year={2021}
}
```
## Content
- [Installation](#installation)
- [Logging](#logging)
- [Data Preparation](#data-preparation)
- [Training and Testing](#training-and-testing)
- [Intrinsic Dimension Estimation](#intrinsic-dimension-estimation)
- [Long-term Prediction and Stability Evaluation](#long-term-prediction-and-stability-evaluation)
- [Evaluation and Analysis](#evaluation-and-analysis)
- [License](#license)
## Installation
The installation has been tested on Ubuntu 18.04 with CUDA 11.0. All experiments were performed on a single NVIDIA GeForce RTX 2080 Ti GPU.
Create a python virtual environment and install the dependencies.
```
virtualenv -p /usr/bin/python3.6 env3.6
source env3.6/bin/activate
pip install -r requirements.txt
```
**Note**: You may need to install Matlab on your computer to use some of the data collectors and intrinsic dimension estimation algorithms.
## Logging
We first introduce the naming convention of the saved files so that it is clear what will be saved and where they will be saved.
1. Log folder naming convention:
```
logs_{dataset}_{model_name}_{seed}
```
2. Inside the logs folder, the structure and contents are:
```
\logs_{dataset}_{model_name}_{seed}
\lightning_logs
\checkpoints [saved checkpoint]
\version_0 [training stats]
\version_1 [testing stats]
\predictions [testing predicted images]
\prediction_long_term [long term predicted images]
\variables [file id and latent vectors on testing data]
\variables_train [file id and latent vectors on training data]
\variables_val [file id and latent vectors on validation data]
```
## Data Preparation
We provide nine datasets with their own download links below.
- [circular_motion](https://drive.google.com/file/d/19DFYqh08B2L-YX_bTmpmxYu2jpARqoUA/view?usp=sharing) (circular motion system)
- [reaction_diffusion](https://drive.google.com/file/d/1LOqpB0l86lELEegX5sV9NwHkGKeHoaTN/view?usp=sharing) (reaction diffusion system)
- [single_pendulum](https://drive.google.com/file/d/1rw6vrVcV3KCIaVVwKMMhoZ9Jph5u-Zdf/view?usp=sharing) (single pendulum system)
- [double_pendulum](https://drive.google.com/file/d/1QEtk4JjnRysIEjtkZIKBjACu_IiTAdX6/view?usp=sharing) (rigid double pendulum system)
- [elastic_pendulum](https://drive.google.com/file/d/1Y8vzawQZhzPHp6cpLFuyM2LHuhgLjH9H/view?usp=sharing) (elastic double pendulum system)
- [swingstick_non_magnetic](https://drive.google.com/file/d/1BfeGW4XTFyGdyBO0G_YnSnRyJGu2WRnc/view?usp=sharing) (swing stick system)
- [air_dancer](https://drive.google.com/file/d/163KvwevY1fnDnI6WiWpc5Zaz-WWwPaAS/view?usp=sharing) (air dancer system)
- [lava_lamp](https://drive.google.com/file/d/1R-I2CZaJLe2D4H818-n25l54R0pbKQfo/view?usp=sharing) (lava lamp system)
- [fire](https://drive.google.com/file/d/1OIH6SalwPyD_lkXBLZq6bzmKfrp-xhrI/view?usp=sharing) (fire system)
Save the downloaded dataset as ```data/{dataset_name}```, where ```data``` is your customized dataset folder. **Please make sure that ```data``` is an absolute path and you need to change the ```data_filepath``` item in the ```config.yaml``` files in ```configs``` to specify your customized dataset folder**.
Please refer to the [datainfo](datainfo) folder for more details about data structure and dataset collection process.
## Training and Testing
Our approach involves three models:
- dynamics predictive model (encoder-decoder / encoder-decoder-64)
- latent reconstruction model (refine-64)
- neural latent dynamics model (latentpred)
1. Navigate to the scripts folder
```
cd scripts
```
2. Train the dynamics predictive model (encoder-decoder and encoder-decoder-64) and then save the high-dimensional latent vectors from the testing data.
```
./encoder_decoder_64_train.sh {dataset_name} {gpu no.}
./encoder_decoder_train.sh {dataset_name} {gpu no.}
./encoder_decoder_64_eval.sh {dataset_name} {gpu no.}
./encoder_decoder_eval.sh {dataset_name} {gpu no.}
```
3. Run a forward pass of the trained encoder-decoder and encoder-decoder-64 models to save the high-dimensional latent vectors from the training and validation data. The saved latent vectors will be used as the training and validation data for the latent reconstruction model.
```
./encoder_decoder_eval_gather.sh {dataset_name} {gpu no.}
./encoder_decoder_64_eval_gather.sh {dataset_name} {gpu no.}
```
4. Before proceeding with this step, please refer to the next section to obtain the system's intrinsic dimension, and then come back to this step.
Train the latent reconstruction model (refine-64) with the saved 64-dim latent vectors from previous steps. Then save the obtained Neural State Variables from both training and testing data.
```
./refine_64_train.sh {dataset_name} {gpu no.}
./refine_64_eval.sh {dataset_name} {gpu no.}
./refine_64_eval_gather.sh {dataset_name} {gpu no.}
```
5. Train the neural latent dynamics model (latentpred) with the trained models from previous steps.
```
./latentpred_train.sh single_pendulum {dataset_name} {gpu no.}
```
## Intrinsic Dimension Estimation
With the trained dynamics predictive model, our approach provides subroutines to estimate the system's intrinsic dimension (ID) using manifold learning algorithms. The estimated intrinsic dimension will be used to decide the number of Neural State Variables and to design the latent reconstruction model. Only after this step can you proceed to train the latent reconstruction model to obtain Neural State Variables.
1. Navigate to the scripts folder
```
cd scripts
```
which is the default directory saving all models' log folders.
2. Estimate the intrinsic dimension from the saved latent vectors of the encoder-decoder models for all random seeds.
```
./encoder_decoder_estimate_dimension.sh {dataset_name}
```
3. Calculate the final intrinsic dimension estimated values (mean and standard deviation).
```
python ../utils/dimension.py {dataset_name}
```
## Long-term Prediction and Stability Evaluation
With all above trained models, our approach offers system long-term predictions through model rollouts as well as stability evaluation of the long-term predictions.
1. Navigate to the scripts folder
```
cd scripts
```
2. Long-term prediction with single model rollouts.
```
./encoder_decoder_long_term_model_rollout.sh {dataset_name} {gpu no.}
./encoder_decoder_64_long_term_model_rollout.sh {dataset_name} {gpu no.}
./refine_64_long_term_model_rollout.sh {dataset_name} {gpu no.}
```
The predictions will be saved in the ```prediction_long_term``` subfolder under the model's log folder.
3. Long-term prediction with hybrid model rollouts.
```
./refine_64_long_term_hybrid_rollout.sh {dataset_name} {gpu no.} {step}
```
where ```step``` is the number of model rollouts via 64-dim latent vectors before a model rollout via Neural State Variables.
The predictions will be saved in the ```prediction_long_term``` subfolder under the refine-64 model's log folder.
4. Long-term prediction with single model rollouts from perturbed initial frames.
```
./encoder_decoder_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.}
./encoder_decoder_64_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.}
./refine_64_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.}
```
The predictions will be saved in the ```prediction_long_term``` subfolder under the model's log folder.
5. Stability evaluation on long-term predictions with single model rollouts
```
./long_term_eval_stability.sh {dataset_name} {gpu no.}
```
and on long-term predictions with hybrid model rollouts
```
./long_term_eval_stability_hybrid.sh {dataset_name} {gpu no.} {step}
```
where ```step``` is as mentioned above for hybrid model rollouts.
The evaluated latent space errors measuring the prediction stability will be saved in the ```stability.npy``` file under the respective ```prediction_long_term``` folders.
**Note**: The default long-term prediction length is 60 (in frames). You will need to modify the scripts if you want to use a different prediction length.
## Evaluation and Analysis
Please refer to the [analysis](analysis) folder for detailed instructions for physical evaluation and analysis.
## License
This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
================================================
FILE: analysis/README.md
================================================
## Physics Evaluation on Known Systems
We provide physics evaluation algorithms extracting physical variables including positions, velocities, and energies from ground truth and predicted frames for the three known systems (single pendulum, rigid double pendulum, and elastic double pendulum).
You can use these algorithms to evaluate the physical accuracy of predictions.
1. Evaluate physical variables from ground truth data.
```
python eval_phys_data.py dataset_name data_filepath
```
The ```dataset_name``` can be single_pendulum, double_pendulum, or elastic_pendulum. The ```data_filepath``` is the respective data filepath, e.g., ```data/single_pendulum``` where ```data``` is your customized data folder. The results will be saved in the ```phys_vars.npy``` file under ```data_filepath```.
2. Evaluate physical variables from long-term predictions via model rollouts and compute physical errors between ground truth data and the predictions.
```
python eval_phys_long_term_pred.py config_filepath pred_save_path
```
The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The ```pred_save_path``` is the path to the predicted images, e.g., ```../scripts/logs_single_pendulum_encoder-decoder_1/prediction_long_term/model_rollout/```. The results will be saved in the ```phys_vars.npy```, ```phys_error.npy```, and ```pixel_error.npy``` files under ```pred_save_path```.
**Note**: ```eval_phys_single_pendulum```, ```eval_phys_double_pendulum```, and ```eval_phys_elastic_pendulum``` are helper packages for evaluating physical variables for the above command lines.
## Intrinsic Dimension Estimation
1. Navigate to the scripts folder
```
cd ../scripts
```
which is the default directory saving all models' log folders.
2. Estimate intrinsic dimension from model latent vectors using the Levina-Bickel method
```
python ../analysis/eval_intrinsic_dimension.py {config_filepath} model-latent NA
```
or using all methods (Levina-Bickel, MiND-ML, MiND-KL, Hein, CD)
```
python ../analysis/eval_intrinsic_dimension.py {config_filepath} model-latent all-methods
```
The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The results will be saved in the ```intrinsic_dimension.npy``` file (using the Levina-Bickel method) or the ```intrinsic_dimension_all_methods.npy``` file (using all methods) in the ```variables``` subfolder under the model's log folder.
3. Estimate intrinsic dimension from raw data (image pairs) using the Levina-Bickel method
```
python ../analysis/eval_intrinsic_dimension.py {config_filepath} data-image NA
```
or using all methods (Levina-Bickel, MiND-ML, MiND-KL, Hein, CD)
```
python ../analysis/eval_intrinsic_dimension.py {config_filepath} data-image all-methods
```
The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The results will be saved in the ```intrinsic_dimension_image.npy``` file (using the Levina-Bickel method) or the ```intrinsic_dimension_image_all_methods.npy``` file (using all methods) in the ```variables``` subfolder under the model's log folder. Here the model configuration only provides data filepath and test video ids.
**Note**: ```intrinsic_dimension_estimation``` are helper packages providing various intrinsic dimension estimation methods. When choosing the ```all-methods``` option, the MATLAB and MATLAB Engine API for Python should be installed on the machine, see instructions [here](https://www.mathworks.com/help/matlab/matlab_external/install-the-matlab-engine-for-python.html).
References of MATLAB codes:
```
[1] Gabriele Lombardi (2021). Intrinsic dimensionality estimation techniques (https://www.mathworks.com/matlabcentral/fileexchange/40112-intrinsic-dimensionality-estimation-techniques), MATLAB Central File Exchange. Retrieved October 21, 2021.
[2] M. Hein and J.-Y. Audibert, Intrinsic dimensionality estimation of submanifolds in Euclidean space, Proceedings of the 22nd International Conference on Machine Learning (https://www.ml.uni-saarland.de/code/IntDim/IntDim.htm).
```
## Latent Space Regression on Known Systems
1. Navigate to the scripts folder
```
cd ../scripts
```
which is the default directory saving all models' log folders.
2. Regress physical variables from the model latent vectors
```
python ../analysis/eval_regression.py config_filepath NA
```
or from the first ```num_components``` principal components of the model latent vectors
```
python ../analysis/eval_regression.py config_filepath num_components
```
The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The results will be saved in the ```regression_results.npy``` file (with model latent vectors) or the ```regression_results_pca_{num_components}.npy``` file (with first few principal components of model latent vectors) in the ```variables``` subfolder under the model's log folder.
================================================
FILE: analysis/eval_intrinsic_dimension.py
================================================
import numpy as np
import os
import sys
from tqdm import tqdm
from PIL import Image
import json
import yaml
import pprint
from munch import munchify
from intrinsic_dimension_estimation import ID_Estimator
def load_config(filepath):
    """Load and return the contents of a YAML configuration file.

    Args:
        filepath: path to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` (a dict for the mapping-style
        config files this script reads).

    Raises:
        yaml.YAMLError: if the file is not valid YAML. The error is printed and
            then re-raised, so the failure surfaces here instead of later as an
            AttributeError on a ``None`` config.
    """
    with open(filepath, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
def remove_duplicates(X):
    """Drop repeated rows of ``X``.

    Returns the distinct rows in the lexicographic order produced by
    ``np.unique(..., axis=0)``.
    """
    distinct_rows = np.unique(X, axis=0)
    return distinct_rows
def eval_id_latent(vars_filepath, if_refine, if_all_methods):
    """Estimate the intrinsic dimension of saved model latent vectors.

    Args:
        vars_filepath: folder containing ``latent.npy`` (or
            ``refine_latent.npy``); the result file is written here too.
        if_refine: read ``refine_latent.npy`` instead of ``latent.npy``.
        if_all_methods: run ``ID_Estimator.fit_all_methods`` and save to
            ``intrinsic_dimension_all_methods.npy``; otherwise run
            ``ID_Estimator.fit`` and save to ``intrinsic_dimension.npy``.

    Returns:
        The estimates returned by the estimator (also saved with ``np.save``).
    """
    latent_name = 'refine_latent.npy' if if_refine else 'latent.npy'
    latent = np.load(os.path.join(vars_filepath, latent_name))
    print(f'Number of samples: {latent.shape[0]}; Latent dimension: {latent.shape[1]}')
    latent = remove_duplicates(latent)
    print(f'Number of samples (duplicates removed): {latent.shape[0]}')
    estimator = ID_Estimator()
    # Neighbourhood sizes scale with the sample count: 5 values from 0.8% to 1.6%.
    k_list = (latent.shape[0] * np.linspace(0.008, 0.016, 5)).astype('int')
    print(f'List of numbers of nearest neighbors: {k_list}')
    if if_all_methods:
        dims = estimator.fit_all_methods(latent, k_list)
        out_name = 'intrinsic_dimension_all_methods.npy'
    else:
        dims = estimator.fit(latent, k_list)
        out_name = 'intrinsic_dimension.npy'
    np.save(os.path.join(vars_filepath, out_name), dims)
    # Returned so that callers (e.g. `dims = eval_id_latent(...)`) get the result.
    return dims
def eval_id_image(data_filepath, test_vid_ids, vars_filepath, if_all_methods):
    """Estimate the intrinsic dimension of raw image pairs from test videos.

    Each sample is two consecutive frames. Each frame is resized to 128x128
    and scaled by 1/255; the two are concatenated along axis 1 and flattened
    into one row.

    Args:
        data_filepath: dataset folder with one sub-folder per video id, each
            holding frames named ``0.<ext>``, ``1.<ext>``, ...
        test_vid_ids: ids of the videos to read.
        vars_filepath: folder where the result file is written.
        if_all_methods: run ``ID_Estimator.fit_all_methods`` and save to
            ``intrinsic_dimension_image_all_methods.npy``; otherwise run
            ``ID_Estimator.fit`` and save to ``intrinsic_dimension_image.npy``.

    Returns:
        The estimates returned by the estimator (also saved with ``np.save``).
    """
    print('Reading image data...')
    data = []
    for p_vid_idx in tqdm(test_vid_ids):
        vid_filepath = os.path.join(data_filepath, str(p_vid_idx))
        # List the folder once: its size is the frame count, and the first
        # entry is assumed to carry the extension shared by all frames.
        frame_files = os.listdir(vid_filepath)
        if not frame_files:
            # An empty folder yields no samples (and has no extension to infer).
            continue
        num_frames = len(frame_files)
        suf = frame_files[0].split('.')[-1]
        # NOTE(review): a pair only needs frame p+1, yet the last 3 start frames
        # are skipped; presumably this matches the model's sample windows -- confirm.
        for p_frame in range(num_frames - 3):
            img_list = []
            for p in range(2):
                img_path = os.path.join(vid_filepath, str(p_frame + p) + '.' + suf)
                # `with` closes the underlying file handle deterministically.
                with Image.open(img_path) as raw:
                    img_list.append(np.array(raw.resize((128, 128))) / 255.0)
            data.append(np.concatenate(img_list, 1).reshape([1, -1]))
    data = np.concatenate(data, 0)
    print(f'Number of samples: {data.shape[0]}; Image dimension (flattened): {data.shape[1]}')
    data = remove_duplicates(data)
    print(f'Number of samples (duplicates removed): {data.shape[0]}')
    estimator = ID_Estimator()
    # Neighbourhood sizes scale with the sample count: 5 values from 0.8% to 1.6%.
    k_list = (data.shape[0] * np.linspace(0.008, 0.016, 5)).astype('int')
    print(f'List of numbers of nearest neighbors: {k_list}')
    if if_all_methods:
        dims = estimator.fit_all_methods(data, k_list)
        out_name = 'intrinsic_dimension_image_all_methods.npy'
    else:
        dims = estimator.fit(data, k_list)
        out_name = 'intrinsic_dimension_image.npy'
    np.save(os.path.join(vars_filepath, out_name), dims)
    return dims
if __name__ == '__main__':
    # Usage (run from the scripts folder, since '../datainfo' is relative):
    #   python eval_intrinsic_dimension.py config_filepath {model-latent|data-image} {all-methods|NA}
    if len(sys.argv) < 4:
        sys.exit('Usage: python eval_intrinsic_dimension.py config_filepath '
                 '{model-latent|data-image} {all-methods|NA}')
    config_filepath = str(sys.argv[1])
    mode = str(sys.argv[2])
    # Anything other than 'all-methods' (e.g. 'NA') selects the single-method `fit`.
    if_all_methods = (str(sys.argv[3]) == 'all-methods')
    if mode not in ('model-latent', 'data-image'):
        # Explicit raise: an `assert False` would be stripped under `python -O`.
        raise ValueError(f'Invalid arguments... unknown mode {mode!r}')

    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    # Both modes write their results into the model's `variables` folder.
    log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])
    vars_filepath = os.path.join(log_dir, 'variables')

    if mode == 'model-latent':
        if_refine = ('refine' in cfg.model_name)
        dims = eval_id_latent(vars_filepath, if_refine, if_all_methods)
    else:
        data_filepath = os.path.join(cfg.data_filepath, cfg.dataset)
        split_filepath = os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json')
        with open(split_filepath, 'r') as file:
            seq_dict = json.load(file)
        test_vid_ids = seq_dict['test']
        dims = eval_id_image(data_filepath, test_vid_ids, vars_filepath, if_all_methods)
================================================
FILE: analysis/eval_phys_data.py
================================================
import os
import sys
import cv2
import numpy as np
from tqdm import tqdm
def eval_phys_data_single_pendulum(data_filepath, num_vids, num_frms, save_path):
    """Evaluate physical variables for every single-pendulum video and save them.

    Args:
        data_filepath: dataset folder with sub-folders ``0 .. num_vids-1``, each
            holding frames ``0.png .. {num_frms-1}.png``.
        num_vids: number of videos to process.
        num_frms: number of frames read from each video.
        save_path: output ``.npy`` path. A dict mapping each variable name to
            ``np.array`` of the per-video results is written with ``np.save``
            (load it with ``allow_pickle=True``).

    Raises:
        FileNotFoundError: if a frame cannot be read.
    """
    from eval_phys_single_pendulum import eval_physics, phys_vars_list
    phys = {p_var: [] for p_var in phys_vars_list}
    for n in tqdm(range(num_vids)):
        seq_filepath = os.path.join(data_filepath, str(n))
        frames = []
        for p in range(num_frms):
            frame_path = os.path.join(seq_filepath, str(p) + '.png')
            frame_p = cv2.imread(frame_path)
            # cv2.imread returns None (instead of raising) on a missing/unreadable file.
            if frame_p is None:
                raise FileNotFoundError(f'Cannot read frame: {frame_path}')
            frames.append(frame_p)
        phys_tmp = eval_physics(frames)
        for p_var in phys_vars_list:
            phys[p_var].append(phys_tmp[p_var])
    for p_var in phys_vars_list:
        phys[p_var] = np.array(phys[p_var])
    np.save(save_path, phys)
def eval_phys_data_double_pendulum(data_filepath, num_vids, num_frms, save_path):
    """Evaluate physical variables for every rigid double-pendulum video and save them.

    Outlier removal: a frame is an outlier when |vel_theta_1| or |vel_theta_2|
    is at or above its 98th percentile over the whole dataset (NaNs ignored).
    For outliers, both angular velocities, the kinetic energy and the total
    energy are set to NaN.

    Args:
        data_filepath: dataset folder with sub-folders ``0 .. num_vids-1``, each
            holding frames ``0.png .. {num_frms-1}.png``.
        num_vids: number of videos to process.
        num_frms: number of frames read from each video.
        save_path: output ``.npy`` path. A dict mapping each variable name to
            ``np.array`` of the per-video results is written with ``np.save``
            (load it with ``allow_pickle=True``).

    Raises:
        FileNotFoundError: if a frame cannot be read.
    """
    from eval_phys_double_pendulum import eval_physics, phys_vars_list
    phys = {p_var: [] for p_var in phys_vars_list}
    for n in tqdm(range(num_vids)):
        seq_filepath = os.path.join(data_filepath, str(n))
        frames = []
        for p in range(num_frms):
            frame_path = os.path.join(seq_filepath, str(p) + '.png')
            frame_p = cv2.imread(frame_path)
            # cv2.imread returns None (instead of raising) on a missing/unreadable file.
            if frame_p is None:
                raise FileNotFoundError(f'Cannot read frame: {frame_path}')
            frames.append(frame_p)
        phys_tmp = eval_physics(frames)
        for p_var in phys_vars_list:
            phys[p_var].append(phys_tmp[p_var])
    for p_var in phys_vars_list:
        phys[p_var] = np.array(phys[p_var])
    # remove outliers (vectorized; the mask is computed before any value is
    # overwritten, exactly as the per-element loop evaluated it).
    # eval_physics returns one value per frame, so arrays are (num_vids, num_frms).
    vel_1 = phys['vel_theta_1']
    vel_2 = phys['vel_theta_2']
    thresh_1 = np.nanpercentile(np.abs(vel_1), 98)
    thresh_2 = np.nanpercentile(np.abs(vel_2), 98)
    with np.errstate(invalid='ignore'):
        outliers = ((~np.isnan(vel_1) & (np.abs(vel_1) >= thresh_1))
                    | (~np.isnan(vel_2) & (np.abs(vel_2) >= thresh_2)))
    for p_var in ('vel_theta_1', 'vel_theta_2', 'kinetic energy', 'total energy'):
        phys[p_var][outliers] = np.nan
    np.save(save_path, phys)
def eval_phys_data_elastic_pendulum(data_filepath, num_vids, num_frms, save_path):
    """Evaluate physical variables for every elastic double-pendulum video and save them.

    Outlier removal: a frame is an outlier when |vel_theta_1|, |vel_theta_2|
    or |vel_z| is at or above its 98th percentile over the whole dataset
    (NaNs ignored). For outliers, the three velocities, the kinetic energy
    and the total energy are set to NaN.

    Args:
        data_filepath: dataset folder with sub-folders ``0 .. num_vids-1``, each
            holding frames ``0.png .. {num_frms-1}.png``.
        num_vids: number of videos to process.
        num_frms: number of frames read from each video.
        save_path: output ``.npy`` path. A dict mapping each variable name to
            ``np.array`` of the per-video results is written with ``np.save``
            (load it with ``allow_pickle=True``).

    Raises:
        FileNotFoundError: if a frame cannot be read.
    """
    from eval_phys_elastic_pendulum import eval_physics, phys_vars_list
    phys = {p_var: [] for p_var in phys_vars_list}
    for n in tqdm(range(num_vids)):
        seq_filepath = os.path.join(data_filepath, str(n))
        frames = []
        for p in range(num_frms):
            frame_path = os.path.join(seq_filepath, str(p) + '.png')
            frame_p = cv2.imread(frame_path)
            # cv2.imread returns None (instead of raising) on a missing/unreadable file.
            if frame_p is None:
                raise FileNotFoundError(f'Cannot read frame: {frame_path}')
            frames.append(frame_p)
        phys_tmp = eval_physics(frames)
        for p_var in phys_vars_list:
            phys[p_var].append(phys_tmp[p_var])
    for p_var in phys_vars_list:
        phys[p_var] = np.array(phys[p_var])
    # remove outliers (vectorized; the mask is computed before any value is
    # overwritten, exactly as the per-element loop evaluated it).
    # NOTE(review): assumes eval_physics returns one value per frame, so the
    # arrays are (num_vids, num_frms) -- confirm in eval_phys_elastic_pendulum.
    vel_1 = phys['vel_theta_1']
    vel_2 = phys['vel_theta_2']
    vel_z = phys['vel_z']
    thresh_1 = np.nanpercentile(np.abs(vel_1), 98)
    thresh_2 = np.nanpercentile(np.abs(vel_2), 98)
    thresh_z = np.nanpercentile(np.abs(vel_z), 98)
    with np.errstate(invalid='ignore'):
        outliers = ((~np.isnan(vel_1) & (np.abs(vel_1) >= thresh_1))
                    | (~np.isnan(vel_2) & (np.abs(vel_2) >= thresh_2))
                    | (~np.isnan(vel_z) & (np.abs(vel_z) >= thresh_z)))
    for p_var in ('vel_theta_1', 'vel_theta_2', 'vel_z', 'kinetic energy', 'total energy'):
        phys[p_var][outliers] = np.nan
    np.save(save_path, phys)
if __name__ == '__main__':
    # Usage: python eval_phys_data.py dataset_name data_filepath
    if len(sys.argv) < 3:
        sys.exit('Usage: python eval_phys_data.py dataset_name data_filepath')
    dataset = str(sys.argv[1])
    data_filepath = str(sys.argv[2])
    save_path = os.path.join(data_filepath, 'phys_vars.npy')
    # dataset name -> (evaluation function, number of videos, frames per video)
    evaluators = {
        'single_pendulum': (eval_phys_data_single_pendulum, 1200, 60),
        'double_pendulum': (eval_phys_data_double_pendulum, 1100, 60),
        'elastic_pendulum': (eval_phys_data_elastic_pendulum, 1200, 60),
    }
    if dataset not in evaluators:
        # Explicit raise: an `assert False` would be stripped under `python -O`.
        raise ValueError(f'Unknown system... {dataset!r}')
    eval_fn, num_vids, num_frms = evaluators[dataset]
    eval_fn(data_filepath, num_vids, num_frms, save_path)
================================================
FILE: analysis/eval_phys_double_pendulum/__init__.py
================================================
from .angle_estimator import obtain_angle
from .physics_estimator import *
import numpy as np
# Names of the physical variables produced by eval_physics.
# NOTE: despite the "_list" suffix this is a set, so its iteration order (and
# the key order of dicts built from it) may differ between interpreter runs
# because of string hash randomization.
phys_vars_list = {'reject', 'theta_1', 'vel_theta_1', 'theta_2', 'vel_theta_2', 'kinetic energy',
'potential energy', 'total energy'}
def eval_physics(frames):
    """Estimate angles, angular velocities and energies of a double pendulum.

    Args:
        frames: sequence of images, one per time step, passed one at a time
            to ``obtain_angle``.

    Returns:
        dict keyed by the names in ``phys_vars_list``. ``reject``, the angles
        and the angular velocities are arrays of length ``len(frames)``.
        Frames rejected by the angle estimator have ``reject=True`` and NaN
        angles and velocities. The energies are whatever ``calc_energy``
        returns for these arrays.
    """
    n_frames = len(frames)
    reject = np.zeros(n_frames, dtype=bool)
    theta_1 = np.zeros(n_frames)
    theta_2 = np.zeros(n_frames)

    # Per-frame angle estimation; rejected frames get NaN angles.
    for idx, frame in enumerate(frames):
        is_rejected, angles, _ = obtain_angle(frame)
        if is_rejected:
            reject[idx] = True
            theta_1[idx] = theta_2[idx] = np.nan
        else:
            theta_1[idx], theta_2[idx] = angles[0], angles[1]

    # Velocities are computed separately on each maximal run of accepted
    # frames, so rejected (NaN) frames are never passed to calc_velocity.
    vel_theta_1 = np.zeros(n_frames)
    vel_theta_2 = np.zeros(n_frames)
    accepted_runs = np.ma.clump_unmasked(np.ma.masked_array(theta_1, reject))
    for run in accepted_runs:
        vel_theta_1[run] = calc_velocity(theta_1[run].copy())
        vel_theta_2[run] = calc_velocity(theta_2[run].copy())
    vel_theta_1[reject] = np.nan
    vel_theta_2[reject] = np.nan

    kinetic_energy, potential_energy, total_energy = calc_energy(
        theta_1, theta_2, vel_theta_1, vel_theta_2)

    # Key order follows phys_vars_list; values are filled in afterwards.
    phys = dict.fromkeys(phys_vars_list)
    phys.update({
        'reject': reject,
        'theta_1': theta_1,
        'vel_theta_1': vel_theta_1,
        'theta_2': theta_2,
        'vel_theta_2': vel_theta_2,
        'kinetic energy': kinetic_energy,
        'potential energy': potential_energy,
        'total energy': total_energy,
    })
    return phys
================================================
FILE: analysis/eval_phys_double_pendulum/angle_estimator.py
================================================
'''
This script provides utility functions estimating angles of a double pendulum from image.
Step 1. Extract each of the two pendulums from the image.
Step 2. Do a rectangle fitting of each pendulum.
Step 3. Estimate the angles of the pendulums.
Certain images will be rejected if any pendulum does not exist, is hidden, or has a wrong shape.
'''
import cv2
import os
import numpy as np
def seg_from_img(img):
    """Split a double-pendulum frame into one mask per arm.

    Args:
        img: double pendulum image in BGR format.

    Returns:
        (seg_1, seg_2): ``cv2.inRange`` masks (255 where selected) for the
        first arm and the second arm.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # First arm: every pixel whose HSV value channel is at most 140 (dark pixels).
    arm_1 = cv2.inRange(hsv, (0, 0, 0), (255, 255, 140))
    # Paint the first arm with the background colour (HSV of RGB [215, 205, 192])
    # so that it cannot show up in the second arm's mask.
    hsv[arm_1 == 255] = [17, 27, 215]
    # Second arm: hue of at least 60.
    arm_2 = cv2.inRange(hsv, (60, 0, 0), (255, 255, 255))
    return arm_1, arm_2
def fit_pendulum(seg, *, min_area=125, min_side_diff=8, max_box_to_area=2):
    """Fit the largest blob of a segmentation mask with a rotated rectangle.

    Args:
        seg: single-channel segmentation mask of one pendulum arm.
        min_area: reject when the largest contour's area is below this (px^2).
        min_side_diff: reject when the rectangle's sides differ by less than
            this many pixels (too close to a square to define a direction).
        max_box_to_area: reject when the rectangle area exceeds this multiple
            of the contour area (the rectangle does not fit the blob well).

    Returns:
        rej: True if the image is rejected.
        rect: the fitted ``cv2.minAreaRect`` Box2D ((cx, cy), (w, h), angle),
            or None when rejected.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); the contour list is second-to-last in every version.
    contours = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return True, None
    cnt = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    if area < min_area:
        return True, None
    rect = cv2.minAreaRect(cnt)
    _, (width, height), _ = rect
    if abs(height - width) < min_side_diff:
        return True, None
    if height * width > max_box_to_area * area:
        return True, None
    return False, rect
def estimate_angle(rect, ref_pt):
    """Estimate a pendulum arm's angle from its fitted rectangle.

    Args:
        rect: (Box2D structure) the fitted rectangle.
        ref_pt: reference point; the returned direction points away from it.

    Returns:
        angle: angle in radians in [0, 2*pi); 0 means pointing straight down
            in the image and pi/2 means pointing to the right.
        box: the rectangle's four corner points.
        arrow: ((x0, y0), (x1, y1)) integer endpoints of an arrow drawn from
            the rectangle centre along the direction.
    """
    box = cv2.boxPoints(rect)
    (cx, cy), (width, height), _ = rect
    # NOTE(review): the edge chosen here is presumably the rectangle's long
    # side (the arm axis) given OpenCV's boxPoints ordering.
    direction = box[0] - box[1] if width < height else box[0] - box[3]
    # Flip the direction if it points back toward the reference point.
    away_from_ref = np.array([cx, cy]) - ref_pt
    if np.dot(direction, away_from_ref) < 0:
        direction = -direction
    # Counterclockwise angle (y axis up) from (0, -1) to the direction.
    angle = np.arctan2(-direction[1], direction[0]) + np.pi / 2
    if angle < 0:
        angle += 2 * np.pi
    tail = (int(cx), int(cy))
    head = (int(cx + direction[0]), int(cy + direction[1]))
    return angle, box, (tail, head)
def obtain_angle(img):
    """Obtain the angles of a double pendulum from an image.

    Args:
        img: double pendulum image in BGR format.

    Returns:
        rej: True if the image is rejected (either arm could not be fitted).
        angles: [angle_1, angle_2] in radians in [0, 2*pi), or [] if rejected.
        img_marked: copy of ``img`` annotated with the fitted rectangles,
            direction arrows and angles in degrees (or a rejection message).
    """
    img_marked = img.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    seg_1, seg_2 = seg_from_img(img)
    rej_1, rect_1 = fit_pendulum(seg_1)
    rej_2, rect_2 = fit_pendulum(seg_2)
    if rej_1 or rej_2:
        label = 'Reject arm 1' if rej_1 else 'Reject arm 2'
        cv2.putText(img_marked, label, (10, 20), font, 0.5, (0, 0, 0), 1)
        return True, [], img_marked

    # Arm 1 is oriented away from the image centre (pixel (64, 64)). Mirroring
    # the centre through arm 1's rectangle centre gives the arm's far end,
    # which is the reference point for orienting arm 2.
    img_center = np.array([64, 64])
    angle_1, box_1, arrow_1 = estimate_angle(rect_1, img_center)
    (cx, cy), _, _ = rect_1
    pd_1_end = 2 * np.array([cx, cy]) - img_center
    angle_2, box_2, arrow_2 = estimate_angle(rect_2, pd_1_end)

    # np.int0 (an alias of np.intp) was removed in NumPy 2.0.
    for box in (box_1, box_2):
        cv2.drawContours(img_marked, [box.astype(np.intp)], 0, (0, 0, 255), 2)
    for tail, head in (arrow_1, arrow_2):
        cv2.arrowedLine(img_marked, tail, head, (0, 0, 255), 1, tipLength=0.25)
    text = str(round(angle_1 * 180 / np.pi)) + ' ' + str(round(angle_2 * 180 / np.pi))
    cv2.putText(img_marked, text, (10, 20), font, 0.5, (0, 0, 0), 1)
    return False, [angle_1, angle_2], img_marked
================================================
FILE: analysis/eval_phys_double_pendulum/physics_estimator.py
================================================
'''
This script provides utility functions estimating angular velocities
and other physical quantities from angles of double pendulum.
'''
import numpy as np
from scipy.interpolate import CubicSpline
# Physical parameters of the double pendulum (SI units). calc_energy models
# each arm as a uniform rectangular rod of length L, width w and mass m.
fps = 60        # frame rate of the videos (frames per second)
L1 = 0.205      # length of arm 1 (m)
L2 = 0.179      # length of arm 2 (m)
w1 = 0.038      # width of arm 1 (m)
w2 = 0.038      # width of arm 2 (m)
m1 = 0.262      # mass of arm 1 (kg)
m2 = 0.110      # mass of arm 2 (kg)
g = 9.81        # gravitational acceleration (m/s^2)
def calc_diff(th1, th2):
    """Return the absolute circular distance between angles ``th1`` and ``th2``.

    Works element-wise on arrays. The result lies in [0, pi] as long as
    |th2 - th1| <= 2*pi (e.g. both angles in [0, 2*pi)).
    """
    raw = np.abs(th2 - th1)
    return np.minimum(raw, 2 * np.pi - raw)
def calc_avrg(th1, th2):
    """Return the circular mean of two scalar angles.

    When the angles are more than pi apart, their arithmetic mean lies on the
    far side of the circle, so it is rotated by pi. A negative result is then
    wrapped by adding 2*pi.
    """
    mean = (th1 + th2) / 2
    if np.abs(th2 - th1) > np.pi:
        mean -= np.pi
    return mean + 2 * np.pi if mean < 0 else mean
def calc_velocity(th, method='spline', periodic=True):
    """Differentiate a sequence of angles sampled once per frame at ``fps``.

    Args:
        th: 1-D array of angles (rad), one per frame. It is not modified.
        method: 'fd' for backward finite differences (the first sample reuses
            the forward difference), or 'spline' for the derivative of a
            cubic spline through the samples (finite differences at the two
            end points).
        periodic: if True, jumps larger than pi between consecutive samples
            are first unwrapped by +/- 2*pi so that crossing the 0/2*pi seam
            is not mistaken for fast motion.

    Returns:
        Array of velocities (rad/s) of the same length as ``th``, or
        ``np.nan`` for a single-sample sequence.

    Raises:
        ValueError: if ``method`` is neither 'fd' nor 'spline'.
    """
    # Work on a float copy: the caller's array is left untouched, and integer
    # input no longer breaks the in-place unwrapping below.
    th = np.array(th, dtype=float)
    len_seq = th.shape[0]
    # A single sample has no derivative.
    if len_seq == 1:
        return np.nan
    if periodic:
        # Each jump is judged on the raw consecutive difference and shifts all
        # later samples, hence the cumulative sum of the per-step corrections.
        steps = np.diff(th)
        corrections = np.where(steps > np.pi, -2 * np.pi, np.where(steps < -np.pi, 2 * np.pi, 0.0))
        th[1:] += np.cumsum(corrections)
    if method == 'fd':
        vel_th = np.empty(len_seq)
        vel_th[1:] = np.diff(th) * fps
        vel_th[0] = vel_th[1]
    elif method == 'spline':
        t = np.arange(len_seq) / fps
        vel_th = CubicSpline(t, th)(t, 1)
        # Finite differences at the boundary points improve accuracy there.
        vel_th[0] = (th[1] - th[0]) * fps
        vel_th[-1] = (th[-1] - th[-2]) * fps
    else:
        # Raised explicitly: an assert is stripped under ``python -O``.
        raise ValueError(f'Unrecognizable differentiation method: {method!r}')
    return vel_th
def calc_energy(th1, th2, vel_th1, vel_th2):
    """Return (kinetic, potential, total) energy of the double pendulum.

    Each arm is a uniform rectangular rod whose centre of mass sits halfway
    along it. Angles are measured from the downward vertical. The kinetic
    energy sums the translational energy of both centres of mass and the
    rotational energy about them. Inputs may be scalars or same-shape arrays.
    """
    half_1 = L1 / 2
    cos_1, sin_1 = np.cos(th1), np.sin(th1)
    cos_2, sin_2 = np.cos(th2), np.sin(th2)
    # Heights of the two centres of mass (pivot at 0, y axis up).
    y1 = -half_1 * cos_1
    y2 = (-L1 * cos_1) - (L2 / 2) * cos_2
    # Centre-of-mass velocities.
    vx1 = vel_th1 * half_1 * cos_1
    vy1 = vel_th1 * half_1 * sin_1
    vx2 = vel_th1 * L1 * cos_1 + (vel_th2 / 2) * L2 * cos_2
    vy2 = vel_th1 * L1 * sin_1 + (vel_th2 / 2) * L2 * sin_2
    # Moments of inertia of each rod about its centre.
    inertia_1 = (1. / 12.) * m1 * (w1 ** 2 + L1 ** 2)
    inertia_2 = (1. / 12.) * m2 * (w2 ** 2 + L2 ** 2)
    potential = g * (m1 * y1 + m2 * y2)
    kinetic = (0.5 * m1 * (vx1 ** 2 + vy1 ** 2)
               + 0.5 * m2 * (vx2 ** 2 + vy2 ** 2)
               + 0.5 * inertia_1 * (vel_th1 ** 2)
               + 0.5 * inertia_2 * (vel_th2 ** 2))
    return kinetic, potential, kinetic + potential
================================================
FILE: analysis/eval_phys_elastic_pendulum/__init__.py
================================================
from .angle_estimator import obtain_angle_stretch
from .physics_estimator import *
import numpy as np
# Names of the quantities returned by eval_physics (one per-frame array each).
phys_vars_list = {
    'reject',
    'theta_1', 'vel_theta_1',
    'theta_2', 'vel_theta_2',
    'z', 'vel_z',
    'kinetic energy', 'potential energy', 'total energy',
}
def eval_physics(frames):
    """Estimate the physical state of an elastic double pendulum in every frame.

    Args:
        frames: sequence of BGR images, one per time step.

    Returns:
        dict keyed by the names in ``phys_vars_list``; each value is a 1-D
        array with one entry per frame. Frames the estimator rejects are
        flagged in 'reject' and hold NaN in every other array.
    """
    n = len(frames)
    reject = np.zeros(n, dtype=bool)
    theta_1 = np.zeros(n)
    theta_2 = np.zeros(n)
    z = np.zeros(n)

    # Per-frame angle and stretch estimation.
    for idx, frame in enumerate(frames):
        rejected, angles, stretch, _ = obtain_angle_stretch(frame)
        reject[idx] = rejected
        if rejected:
            theta_1[idx] = theta_2[idx] = z[idx] = np.nan
        else:
            theta_1[idx], theta_2[idx] = angles[0], angles[1]
            z[idx] = stretch

    # Differentiate each contiguous run of accepted frames on its own; the
    # stretch is not an angle, so it is not unwrapped.
    vel_theta_1 = np.zeros(n)
    vel_theta_2 = np.zeros(n)
    vel_z = np.zeros(n)
    for run in np.ma.clump_unmasked(np.ma.masked_array(theta_1, reject)):
        vel_theta_1[run] = calc_velocity(theta_1[run].copy())
        vel_theta_2[run] = calc_velocity(theta_2[run].copy())
        vel_z[run] = calc_velocity(z[run].copy(), periodic=False)
    for vel in (vel_theta_1, vel_theta_2, vel_z):
        vel[reject] = np.nan

    kinetic, potential, total = calc_energy(theta_1, theta_2, z, vel_theta_1, vel_theta_2, vel_z)

    phys = dict.fromkeys(phys_vars_list)
    phys.update({
        'reject': reject,
        'theta_1': theta_1,
        'vel_theta_1': vel_theta_1,
        'theta_2': theta_2,
        'vel_theta_2': vel_theta_2,
        'z': z,
        'vel_z': vel_z,
        'kinetic energy': kinetic,
        'potential energy': potential,
        'total energy': total,
    })
    return phys
================================================
FILE: analysis/eval_phys_elastic_pendulum/angle_estimator.py
================================================
'''
This script provides utility functions for estimating the angles and stretch of an elastic double pendulum from an image.
Step 1. Extract each of the two pendulums from the image.
Step 2. Do a rectangle fitting of each pendulum.
Step 3. Estimate the angles of the pendulums and the stretch of the elastic one.
Certain images will be rejected if any pendulum does not exist, is hidden, or has a wrong shape.
'''
import cv2
import os
import numpy as np
def seg_from_img(img):
    """Split an elastic double-pendulum frame into one mask per arm.

    Args:
        img: elastic double pendulum image in BGR format.

    Returns:
        (seg_1, seg_2): ``cv2.inRange`` masks (255 where selected) for the
        first arm and the second arm.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # First arm: every pixel whose HSV value channel is at most 140 (dark pixels).
    arm_1 = cv2.inRange(hsv, (0, 0, 0), (255, 255, 140))
    # Paint the first arm with the background colour (HSV of RGB [215, 205, 192])
    # so that it cannot show up in the second arm's mask.
    hsv[arm_1 == 255] = [17, 27, 215]
    # Second arm: hue of at least 60.
    arm_2 = cv2.inRange(hsv, (60, 0, 0), (255, 255, 255))
    return arm_1, arm_2
def fit_pendulum(seg, *, min_area=60, min_side_diff=4, max_box_to_area=1.8):
    """Fit the largest blob of a segmentation mask with a rotated rectangle.

    Args:
        seg: single-channel segmentation mask of one pendulum arm.
        min_area: reject when the largest contour's area is below this (px^2).
        min_side_diff: reject when the rectangle's sides differ by less than
            this many pixels (too close to a square to define a direction).
        max_box_to_area: reject when the rectangle area exceeds this multiple
            of the contour area (the rectangle does not fit the blob well).

    Returns:
        rej: True if the image is rejected.
        rect: the fitted ``cv2.minAreaRect`` Box2D ((cx, cy), (w, h), angle),
            or None when rejected.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); the contour list is second-to-last in every version.
    contours = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return True, None
    cnt = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    if area < min_area:
        return True, None
    rect = cv2.minAreaRect(cnt)
    _, (width, height), _ = rect
    if abs(height - width) < min_side_diff:
        return True, None
    if height * width > max_box_to_area * area:
        return True, None
    return False, rect
def estimate_angle(rect, ref_pt):
    """Estimate a pendulum arm's angle from its fitted rectangle.

    Args:
        rect: (Box2D structure) the fitted rectangle.
        ref_pt: reference point; the returned direction points away from it.

    Returns:
        angle: angle in radians in [0, 2*pi); 0 means pointing straight down
            in the image and pi/2 means pointing to the right.
        box: the rectangle's four corner points.
        arrow: ((x0, y0), (x1, y1)) integer endpoints of an arrow drawn from
            the rectangle centre along the direction.
    """
    box = cv2.boxPoints(rect)
    (cx, cy), (width, height), _ = rect
    # NOTE(review): the edge chosen here is presumably the rectangle's long
    # side (the arm axis) given OpenCV's boxPoints ordering.
    direction = box[0] - box[1] if width < height else box[0] - box[3]
    # Flip the direction if it points back toward the reference point.
    away_from_ref = np.array([cx, cy]) - ref_pt
    if np.dot(direction, away_from_ref) < 0:
        direction = -direction
    # Counterclockwise angle (y axis up) from (0, -1) to the direction.
    angle = np.arctan2(-direction[1], direction[0]) + np.pi / 2
    if angle < 0:
        angle += 2 * np.pi
    tail = (int(cx), int(cy))
    head = (int(cx + direction[0]), int(cy + direction[1]))
    return angle, box, (tail, head)
def obtain_angle_stretch(img):
    """Obtain the angles and stretch of an elastic double pendulum from an image.

    Args:
        img: elastic double pendulum image in BGR format.

    Returns:
        rej: True if the image is rejected (either arm could not be fitted).
        angles: [angle_1, angle_2] in radians in [0, 2*pi), or [] if rejected.
        stretch: estimated stretch of the elastic arm (m), or None if rejected.
        img_marked: copy of ``img`` annotated with the fitted rectangles,
            direction arrows, angles (degrees) and stretch (m), or a rejection
            message.
    """
    img_marked = img.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    seg_1, seg_2 = seg_from_img(img)
    rej_1, rect_1 = fit_pendulum(seg_1)
    rej_2, rect_2 = fit_pendulum(seg_2)
    if rej_1 or rej_2:
        label = 'Reject arm 1' if rej_1 else 'Reject arm 2'
        cv2.putText(img_marked, label, (10, 20), font, 0.5, (0, 0, 0), 1)
        return True, [], None, img_marked

    # Arm 1 is oriented away from the image centre (pixel (64, 64)). Mirroring
    # the centre through arm 1's rectangle centre gives the arm's far end,
    # which is the reference point for orienting arm 2.
    img_center = np.array([64, 64])
    angle_1, box_1, arrow_1 = estimate_angle(rect_1, img_center)
    (cx, cy), (width, height), _ = rect_1
    pd_1_end = 2 * np.array([cx, cy]) - img_center
    angle_2, box_2, arrow_2 = estimate_angle(rect_2, pd_1_end)

    # Stretch (m): the relative elongation of the arm's long side in pixels,
    # scaled by its unstretched physical length.
    L_0_img = 24  # unstretched length of the elastic pendulum in pixels
    L_0 = 0.205  # unstretched length of the elastic pendulum (m)
    stretch = (max(width, height) - L_0_img) / L_0_img * L_0

    # np.int0 (an alias of np.intp) was removed in NumPy 2.0.
    for box in (box_1, box_2):
        cv2.drawContours(img_marked, [box.astype(np.intp)], 0, (0, 0, 255), 2)
    for tail, head in (arrow_1, arrow_2):
        cv2.arrowedLine(img_marked, tail, head, (0, 0, 255), 1, tipLength=0.25)
    text = (str(round(angle_1 * 180 / np.pi)) + ' ' + str(round(angle_2 * 180 / np.pi))
            + ' ' + ('%.2f' % stretch))
    cv2.putText(img_marked, text, (10, 20), font, 0.45, (0, 0, 0), 1)
    return False, [angle_1, angle_2], stretch, img_marked
================================================
FILE: analysis/eval_phys_elastic_pendulum/physics_estimator.py
================================================
'''
This script provides utility functions estimating velocities
and other physical quantities of elastic double pendulum.
'''
import numpy as np
from scipy.interpolate import CubicSpline
# Physical parameters of the elastic double pendulum (SI units), used by
# calc_velocity (fps) and calc_energy (the rest).
fps = 60            # frame rate of the videos (frames per second)
# Elastic arm.
L_0 = 0.205         # unstretched length (m)
k = 40.0            # spring constant (kg/s^2, i.e. N/m)
# Rigid arm, modelled in calc_energy as a uniform rectangular rod.
L, w, m = 0.179, 0.038, 0.110   # length (m), width (m), mass (kg)
g = 9.81            # gravitational acceleration (m/s^2)
def calc_diff(th1, th2):
    """Return the absolute circular distance between angles ``th1`` and ``th2``.

    Works element-wise on arrays. The result lies in [0, pi] as long as
    |th2 - th1| <= 2*pi (e.g. both angles in [0, 2*pi)).
    """
    raw = np.abs(th2 - th1)
    return np.minimum(raw, 2 * np.pi - raw)
def calc_avrg(th1, th2):
    """Return the circular mean of two scalar angles.

    When the angles are more than pi apart, their arithmetic mean lies on the
    far side of the circle, so it is rotated by pi. A negative result is then
    wrapped by adding 2*pi.
    """
    mean = (th1 + th2) / 2
    if np.abs(th2 - th1) > np.pi:
        mean -= np.pi
    return mean + 2 * np.pi if mean < 0 else mean
def normalize_angle(theta):
    """Wrap ``theta`` (scalar or array, radians) into [-pi, pi] via atan2."""
    sine, cosine = np.sin(theta), np.cos(theta)
    return np.arctan2(sine, cosine)
def calc_velocity(x, method='spline', periodic=True):
    """Differentiate a sequence sampled once per frame at ``fps``.

    Args:
        x: 1-D array of samples (angles in rad, or lengths), one per frame.
            It is not modified.
        method: 'fd' for backward finite differences (the first sample reuses
            the forward difference), or 'spline' for the derivative of a
            cubic spline through the samples (finite differences at the two
            end points).
        periodic: if True, treat ``x`` as angles and first unwrap jumps
            larger than pi between consecutive samples by +/- 2*pi.

    Returns:
        Array of velocities of the same length as ``x`` (units of ``x`` per
        second), or ``np.nan`` for a single-sample sequence.

    Raises:
        ValueError: if ``method`` is neither 'fd' nor 'spline'.
    """
    # Work on a float copy: the caller's array is left untouched, and integer
    # input no longer breaks the in-place unwrapping below.
    x = np.array(x, dtype=float)
    len_seq = x.shape[0]
    # A single sample has no derivative.
    if len_seq == 1:
        return np.nan
    if periodic:
        # Each jump is judged on the raw consecutive difference and shifts all
        # later samples, hence the cumulative sum of the per-step corrections.
        steps = np.diff(x)
        corrections = np.where(steps > np.pi, -2 * np.pi, np.where(steps < -np.pi, 2 * np.pi, 0.0))
        x[1:] += np.cumsum(corrections)
    if method == 'fd':
        vel_x = np.empty(len_seq)
        vel_x[1:] = np.diff(x) * fps
        vel_x[0] = vel_x[1]
    elif method == 'spline':
        t = np.arange(len_seq) / fps
        vel_x = CubicSpline(t, x)(t, 1)
        # Finite differences at the boundary points improve accuracy there.
        vel_x[0] = (x[1] - x[0]) * fps
        vel_x[-1] = (x[-1] - x[-2]) * fps
    else:
        # Raised explicitly: an assert is stripped under ``python -O``.
        raise ValueError(f'Unrecognizable differentiation method: {method!r}')
    return vel_x
def calc_energy(th1, th2, z, vel_th1, vel_th2, vel_z):
    """Return (kinetic, potential, total) energy of the elastic double pendulum.

    The elastic arm (length L_0 + z) is treated as massless. The rigid arm is
    a uniform rod of mass m whose centre of mass sits halfway along it.
    Angles are measured from the downward vertical. The potential energy is
    gravitational plus elastic (0.5 k z^2). Inputs may be scalars or
    same-shape arrays.
    """
    spring_len = L_0 + z
    half_rod = L / 2
    cos_1, sin_1 = np.cos(th1), np.sin(th1)
    cos_2, sin_2 = np.cos(th2), np.sin(th2)
    # Height of the rigid arm's centre of mass (pivot at 0, y axis up).
    height = -spring_len * cos_1 - half_rod * cos_2
    # Centre-of-mass velocity components.
    vx = vel_th1 * spring_len * cos_1 + vel_th2 * half_rod * cos_2 + vel_z * sin_1
    vy = vel_th1 * spring_len * sin_1 + vel_th2 * half_rod * sin_2 - vel_z * cos_1
    # Moment of inertia of the rigid rod about its centre.
    inertia = (1. / 12.) * m * (w ** 2 + L ** 2)
    potential = m * g * height + 0.5 * k * z ** 2
    kinetic = 0.5 * m * (vx ** 2 + vy ** 2) + 0.5 * inertia * vel_th2 ** 2
    return kinetic, potential, kinetic + potential
================================================
FILE: analysis/eval_phys_long_term_pred.py
================================================
"""
Evaluate long-term prediction accuracy in physical variables.
"""
import os
import sys
import numpy as np
from scipy import stats
import cv2
import json
import yaml
import pprint
from tqdm import tqdm
from munch import munchify
def calc_diff(th1, th2):
    """Return the absolute circular distance between angles ``th1`` and ``th2``.

    Works element-wise on arrays. The result lies in [0, pi] as long as
    |th2 - th1| <= 2*pi (e.g. both angles in [0, 2*pi)).
    """
    raw = np.abs(th2 - th1)
    return np.minimum(raw, 2 * np.pi - raw)
def calc_pixel_MSE(img1, img2):
    """Return the mean squared pixel error after dividing both images by 255."""
    diff = img1 / 255.0 - img2 / 255.0
    return np.mean(diff ** 2)
def load_config(filepath):
    """Parse a YAML config file with ``yaml.safe_load``.

    Returns the parsed object, or None (after printing the parser error) when
    the file is not valid YAML.
    """
    with open(filepath, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
    return None
def eval_pred_physics(data_filepath, pred_save_path, vid_ids, phys_vars_list, eval_physics):
    """Run ``eval_physics`` on every predicted video.

    Args:
        data_filepath: folder with one ground-truth sub-folder of '<p>.png'
            frames per video id.
        pred_save_path: folder with the matching predicted sub-folders.
        vid_ids: video ids to evaluate.
        phys_vars_list: names of the variables to collect from ``eval_physics``.
        eval_physics: callable mapping a list of frames to a dict of arrays.

    Returns:
        dict mapping each variable name to ``np.array`` of the per-video results.
    """
    collected = {p_var: [] for p_var in phys_vars_list}
    for idx in tqdm(vid_ids):
        gt_dir = os.path.join(data_filepath, str(idx))
        pred_dir = os.path.join(pred_save_path, str(idx))
        num_frames = len(os.listdir(gt_dir))
        # When the prediction folder is two frames short, the first two frames
        # are taken from the ground truth instead.
        use_gt_prefix = len(os.listdir(pred_dir)) == num_frames - 2
        frames = [
            cv2.imread(os.path.join(gt_dir if (use_gt_prefix and p < 2) else pred_dir, f'{p}.png'))
            for p in range(num_frames)
        ]
        per_video = eval_physics(frames)
        for p_var in phys_vars_list:
            collected[p_var].append(per_video[p_var])
    return {p_var: np.array(values) for p_var, values in collected.items()}
def load_data_physics(data_filepath, vid_ids, phys_vars_list):
    """Load ground-truth physical variables and keep only the given videos.

    Reads the pickled dict in '<data_filepath>/phys_vars.npy'. Each variable in
    ``phys_vars_list`` is indexed by ``vid_ids`` along its first axis; other
    entries of the dict are returned unchanged.
    """
    phys_path = os.path.join(data_filepath, 'phys_vars.npy')
    phys = np.load(phys_path, allow_pickle=True).item()
    for name in phys_vars_list:
        phys[name] = phys[name][vid_ids]
    return phys
def eval_physics_error(phys_pred, phys_data, phys_vars_list):
    """Compute absolute errors of predicted physical variables.

    Args:
        phys_pred: dict of predicted variables; must contain 'reject'.
        phys_data: dict of ground-truth variables (may lack 'reject').
        phys_vars_list: names of the variables to compare.

    Returns:
        dict holding copies of the prediction and ground-truth rejection masks
        ('reject' and 'reject_data') plus one absolute-error array per
        variable. Angle variables use the circular distance (``calc_diff``).
    """
    phys_error = {'reject': phys_pred['reject'].copy()}
    if 'reject' in phys_data:
        phys_error['reject_data'] = phys_data['reject'].copy()
    else:
        # No ground-truth rejections recorded. Use a bool mask, matching the
        # dtype of 'reject', so that ``~mask`` and boolean indexing work.
        phys_error['reject_data'] = np.zeros(phys_pred['reject'].shape, dtype=bool)
    for p_var in phys_vars_list:
        if p_var == 'reject':
            continue
        if p_var in ('theta', 'theta_1', 'theta_2'):
            phys_error[p_var] = calc_diff(phys_pred[p_var], phys_data[p_var])
        else:
            phys_error[p_var] = np.abs(phys_pred[p_var] - phys_data[p_var])
    return phys_error
def eval_pixel_error(data_filepath, pred_save_path, vid_ids):
    """Per-frame pixel MSE between ground-truth and predicted videos.

    Frames 0 and 1 always get an error of 0; every later frame '<p>.png' is
    compared with ``calc_pixel_MSE``.

    Returns:
        ``np.array`` of per-video error lists (2-D when all videos have the
        same number of frames).
    """
    errors = []
    for idx in tqdm(vid_ids):
        gt_dir = os.path.join(data_filepath, str(idx))
        pred_dir = os.path.join(pred_save_path, str(idx))
        row = []
        for p in range(len(os.listdir(gt_dir))):
            if p < 2:
                row.append(0)
                continue
            name = str(p) + '.png'
            gt = cv2.imread(os.path.join(gt_dir, name))
            pred = cv2.imread(os.path.join(pred_dir, name))
            row.append(calc_pixel_MSE(gt, pred))
        errors.append(row)
    return np.array(errors)
def main():
    """Evaluate long-term predictions in physical variables and pixels.

    Usage: python eval_phys_long_term_pred.py <config.yaml> <pred_save_path>

    Evaluates the physical variables of the predicted test videos, compares
    them and their pixels with the ground truth, and saves 'phys_vars.npy',
    'phys_error.npy' and 'pixel_error.npy' into ``pred_save_path``.

    Raises:
        ValueError: if the config cannot be parsed or names an unknown dataset.
    """
    config_filepath = str(sys.argv[1])
    pred_save_path = str(sys.argv[2])
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    if cfg is None:
        # load_config returns None when parsing fails.
        raise ValueError(f'Could not load config file: {config_filepath}')
    cfg = munchify(cfg)
    data_filepath = os.path.join(cfg.data_filepath, cfg.dataset)
    split_filepath = os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json')
    with open(split_filepath, 'r') as file:
        seq_dict = json.load(file)
    vid_ids = seq_dict['test']
    # Each system package exposes the same (phys_vars_list, eval_physics) pair.
    if cfg.dataset == 'single_pendulum':
        from eval_phys_single_pendulum import phys_vars_list, eval_physics
    elif cfg.dataset == 'double_pendulum':
        from eval_phys_double_pendulum import phys_vars_list, eval_physics
    elif cfg.dataset == 'elastic_pendulum':
        from eval_phys_elastic_pendulum import phys_vars_list, eval_physics
    else:
        # Raised explicitly: an assert would be stripped under ``python -O``.
        raise ValueError(f'Unknown system: {cfg.dataset}')
    phys_pred = eval_pred_physics(data_filepath, pred_save_path, vid_ids, phys_vars_list, eval_physics)
    phys_data = load_data_physics(data_filepath, vid_ids, phys_vars_list)
    phys_error = eval_physics_error(phys_pred, phys_data, phys_vars_list)
    pixel_error = eval_pixel_error(data_filepath, pred_save_path, vid_ids)
    np.save(os.path.join(pred_save_path, 'phys_vars.npy'), phys_pred)
    np.save(os.path.join(pred_save_path, 'phys_error.npy'), phys_error)
    np.save(os.path.join(pred_save_path, 'pixel_error.npy'), pixel_error)


if __name__ == '__main__':
    main()
================================================
FILE: analysis/eval_phys_single_pendulum/__init__.py
================================================
from .angle_estimator import obtain_angle
from .physics_estimator import *
import numpy as np
# Names of the quantities returned by eval_physics (one per-frame array each).
phys_vars_list = {
    'reject',
    'theta', 'vel_theta',
    'kinetic energy', 'potential energy', 'total energy',
}
def eval_physics(frames):
    """Estimate the physical state of a single pendulum in every frame.

    Args:
        frames: sequence of BGR images, one per time step.

    Returns:
        dict keyed by the names in ``phys_vars_list``; each value is a 1-D
        array with one entry per frame. Frames the angle estimator rejects are
        flagged in 'reject' and hold NaN in every other array.
    """
    n = len(frames)
    reject = np.zeros(n, dtype=bool)
    theta = np.zeros(n)

    # Per-frame angle estimation.
    for idx, frame in enumerate(frames):
        rejected, angle, _ = obtain_angle(frame)
        reject[idx] = rejected
        theta[idx] = np.nan if rejected else angle

    # Differentiate each contiguous run of accepted frames on its own.
    vel_theta = np.zeros(n)
    for run in np.ma.clump_unmasked(np.ma.masked_array(theta, reject)):
        vel_theta[run] = calc_velocity(theta[run].copy())
    vel_theta[reject] = np.nan

    kinetic, potential, total = calc_energy(theta, vel_theta)

    phys = dict.fromkeys(phys_vars_list)
    phys.update({
        'reject': reject,
        'theta': theta,
        'vel_theta': vel_theta,
        'kinetic energy': kinetic,
        'potential energy': potential,
        'total energy': total,
    })
    return phys
================================================
FILE: analysis/eval_phys_single_pendulum/angle_estimator.py
================================================
'''
This script provides utility functions estimating the angle of a single pendulum from image.
Step 1. Extract the pendulum from the image.
Step 2. Do a rectangle fitting of the pendulum.
Step 3. Estimate the angle of the pendulum.
Certain images will be rejected if the pendulum does not exist or has a wrong shape.
'''
import cv2
import os
import numpy as np
def seg_from_img(img):
    """Segment the pendulum from a single-pendulum frame.

    Args:
        img: single pendulum image in BGR format.

    Returns:
        ``cv2.inRange`` mask (255 where selected) of pixels with hue >= 60.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower, upper = (60, 0, 0), (255, 255, 255)
    return cv2.inRange(hsv, lower, upper)
def fit_pendulum(seg, *, min_area=400, min_side_diff=35, max_box_to_area=1.5):
    """Fit the largest blob of a segmentation mask with a rotated rectangle.

    Args:
        seg: single-channel segmentation mask of the pendulum.
        min_area: reject when the largest contour's area is below this (px^2).
        min_side_diff: reject when the rectangle's sides differ by less than
            this many pixels (too close to a square to define a direction).
        max_box_to_area: reject when the rectangle area exceeds this multiple
            of the contour area (the rectangle does not fit the blob well).

    Returns:
        rej: True if the image is rejected.
        rect: the fitted ``cv2.minAreaRect`` Box2D ((cx, cy), (w, h), angle),
            or None when rejected.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); the contour list is second-to-last in every version.
    contours = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return True, None
    cnt = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    if area < min_area:
        return True, None
    rect = cv2.minAreaRect(cnt)
    _, (width, height), _ = rect
    if abs(height - width) < min_side_diff:
        return True, None
    if height * width > max_box_to_area * area:
        return True, None
    return False, rect
def estimate_angle(rect):
    """Estimate the pendulum angle from its fitted rectangle.

    The direction is oriented away from pixel (64, 64).

    Args:
        rect: (Box2D structure) the fitted rectangle.

    Returns:
        angle: angle in radians in [0, 2*pi); 0 means pointing straight down
            in the image and pi/2 means pointing to the right.
        box: the rectangle's four corner points.
        arrow: ((x0, y0), (x1, y1)) integer endpoints of an arrow from the
            rectangle centre along 0.7 times the direction vector.
    """
    box = cv2.boxPoints(rect)
    (cx, cy), (width, height), _ = rect
    # NOTE(review): the edge chosen here is presumably the rectangle's long
    # side (the rod axis) given OpenCV's boxPoints ordering.
    direction = box[0] - box[1] if width < height else box[0] - box[3]
    # Flip the direction if it points back toward (64, 64).
    away_from_pivot = np.array([cx-64, cy-64])
    if np.dot(direction, away_from_pivot) < 0:
        direction = -direction
    # Counterclockwise angle (y axis up) from (0, -1) to the direction.
    angle = np.arctan2(-direction[1], direction[0]) + np.pi / 2
    if angle < 0:
        angle += 2 * np.pi
    tail = (int(cx), int(cy))
    head = (int(cx + 0.7 * direction[0]), int(cy + 0.7 * direction[1]))
    return angle, box, (tail, head)
def obtain_angle(img):
    """Obtain the angle of a single pendulum from an image.

    Args:
        img: single pendulum image in BGR format.

    Returns:
        rej: True if the image is rejected.
        angle: estimated angle in radians in [0, 2*pi), or NaN if rejected.
        img_marked: copy of ``img`` annotated with the fitted rectangle, the
            direction arrow and the angle in degrees (or 'Reject').
    """
    img_marked = img.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    rej, rect = fit_pendulum(seg_from_img(img))
    if rej:
        cv2.putText(img_marked, 'Reject', (10, 20), font, 0.5, (0, 0, 0), 1)
        return rej, np.nan, img_marked
    angle, box, arrow = estimate_angle(rect)
    # np.int0 (an alias of np.intp) was removed in NumPy 2.0.
    cv2.drawContours(img_marked, [box.astype(np.intp)], 0, (0, 0, 255), 2)
    cv2.arrowedLine(img_marked, arrow[0], arrow[1], (0, 0, 255), 1, tipLength=0.25)
    cv2.putText(img_marked, str(round(angle * 180 / np.pi)), (10, 20), font, 0.5, (0, 0, 0), 1)
    return rej, angle, img_marked
================================================
FILE: analysis/eval_phys_single_pendulum/physics_estimator.py
================================================
'''
This script provides utility functions estimating angular velocities
and other physical quantities from angles of single pendulum.
'''
import numpy as np
from scipy.interpolate import CubicSpline
# Physical parameters of the single pendulum (SI units). calc_energy models
# the pendulum as a uniform rod pivoted at one end (I = m l^2 / 3).
fps = 60                  # frame rate of the videos (frames per second)
l, m, g = 0.5, 1.0, 9.81  # rod length (m), rod mass (kg), gravity (m/s^2)
def calc_diff(th1, th2):
    """Return the absolute circular distance between angles ``th1`` and ``th2``.

    Works element-wise on arrays. The result lies in [0, pi] as long as
    |th2 - th1| <= 2*pi (e.g. both angles in [0, 2*pi)).
    """
    raw = np.abs(th2 - th1)
    return np.minimum(raw, 2 * np.pi - raw)
def calc_avrg(th1, th2):
    """Return the circular mean of two scalar angles.

    When the angles are more than pi apart, their arithmetic mean lies on the
    far side of the circle, so it is rotated by pi. A negative result is then
    wrapped by adding 2*pi.
    """
    mean = (th1 + th2) / 2
    if np.abs(th2 - th1) > np.pi:
        mean -= np.pi
    return mean + 2 * np.pi if mean < 0 else mean
def normalize_angle(theta):
    """Wrap ``theta`` (scalar or array, radians) into [-pi, pi] via atan2."""
    sine, cosine = np.sin(theta), np.cos(theta)
    return np.arctan2(sine, cosine)
def calc_velocity(th, method='spline', periodic=True):
    """Differentiate a sequence of angles sampled once per frame at ``fps``.

    Args:
        th: 1-D array of angles (rad), one per frame. It is not modified.
        method: 'fd' for backward finite differences (the first sample reuses
            the forward difference), or 'spline' for the derivative of a
            cubic spline through the samples (finite differences at the two
            end points).
        periodic: if True, jumps larger than pi between consecutive samples
            are first unwrapped by +/- 2*pi so that crossing the 0/2*pi seam
            is not mistaken for fast motion.

    Returns:
        Array of angular velocities (rad/s) of the same length as ``th``, or
        ``np.nan`` for a single-sample sequence.

    Raises:
        ValueError: if ``method`` is neither 'fd' nor 'spline'.
    """
    # Work on a float copy: the caller's array is left untouched, and integer
    # input no longer breaks the in-place unwrapping below.
    th = np.array(th, dtype=float)
    len_seq = th.shape[0]
    # A single sample has no derivative.
    if len_seq == 1:
        return np.nan
    if periodic:
        # Each jump is judged on the raw consecutive difference and shifts all
        # later samples, hence the cumulative sum of the per-step corrections.
        steps = np.diff(th)
        corrections = np.where(steps > np.pi, -2 * np.pi, np.where(steps < -np.pi, 2 * np.pi, 0.0))
        th[1:] += np.cumsum(corrections)
    if method == 'fd':
        vel_th = np.empty(len_seq)
        vel_th[1:] = np.diff(th) * fps
        vel_th[0] = vel_th[1]
    elif method == 'spline':
        t = np.arange(len_seq) / fps
        vel_th = CubicSpline(t, th)(t, 1)
        # Finite differences at the boundary points improve accuracy there.
        vel_th[0] = (th[1] - th[0]) * fps
        vel_th[-1] = (th[-1] - th[-2]) * fps
    else:
        # Raised explicitly: an assert is stripped under ``python -O``.
        raise ValueError(f'Unrecognizable differentiation method: {method!r}')
    return vel_th
def calc_energy(th, vel_th):
    """Return (kinetic, potential, total) energy of the single pendulum.

    The pendulum is a uniform rod pivoted at one end (I = m l^2 / 3), so
    T = I w^2 / 2 = m l^2 w^2 / 6. Its centre of mass sits at l / 2, giving
    V = -m g (l / 2) cos(th). Inputs may be scalars or same-shape arrays.
    """
    kinetic = m * l ** 2 / 6 * vel_th ** 2
    potential = -m * g * l / 2 * np.cos(th)
    return kinetic, potential, kinetic + potential
================================================
FILE: analysis/eval_regression.py
================================================
import os
import sys
import numpy as np
from tqdm import tqdm
import json
import yaml
import pprint
from munch import munchify
from latent_regression import pca, mlp_regress
def load_config(filepath):
    """Parse a YAML config file with ``yaml.safe_load``.

    Returns the parsed object, or None (after printing the parser error) when
    the file is not valid YAML.
    """
    with open(filepath, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
    return None
def parse_data_id(data_id):
    """Split a data id such as '7_0.png' into (video number, frame number) ints."""
    video_part, frame_part = data_id.split('_')
    frame_stem = frame_part.split('.')[0]
    return int(video_part), int(frame_stem)
def physical_variables_from_data_ids(phys_all, ids):
    """Gather physical variables for four consecutive frames of each data id.

    Args:
        phys_all: dict of variable name -> array indexed [video, frame]; the
            'reject' entry is ignored.
        ids: array of data ids such as '7_0.png' (see ``parse_data_id``).

    Returns:
        dict mapping '<name> (t=<t>)' for t in 0..3 to a float array with one
        entry per id, holding ``phys_all[name][video, frame + t]``.
    """
    var_names = [name for name in phys_all.keys() if name != 'reject']
    num_data = ids.shape[0]
    phys = {f'{name} (t={t})': np.zeros(num_data) for name in var_names for t in range(4)}
    for n in range(num_data):
        vid_n, frm_n = parse_data_id(ids[n])
        for name in var_names:
            series = phys_all[name]
            for t in range(4):
                phys[f'{name} (t={t})'][n] = series[vid_n, frm_n + t]
    return phys
def _drop_nan_samples(phys, latent, phys_vars_list):
    """Drop samples with a NaN in any of ``phys_vars_list``; filters ``phys`` in place, returns filtered ``latent``."""
    is_outlier = np.full(latent.shape[0], False)
    for p_var in phys_vars_list:
        is_outlier |= np.isnan(phys[p_var])
    for p_var in phys_vars_list:
        phys[p_var] = phys[p_var][~is_outlier]
    return latent[~is_outlier]


def main():
    """Regress physical variables from learned latent vectors.

    Usage: python eval_regression.py <config.yaml> <num_pca_components | NA>

    Loads train/test latent vectors and the matching physical variables,
    drops samples with NaN variables, and optionally projects the latents
    onto their first principal components. It then fits ``mlp_regress`` on
    three random 30% subsets of the training data and saves the errors and
    models under '<log_dir>/variables/'.
    """
    config_filepath = str(sys.argv[1])
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    if_refine = 'refine' in cfg.model_name
    phys_all = np.load(os.path.join(cfg.data_filepath, cfg.dataset, 'phys_vars.npy'), allow_pickle=True).item()
    log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])
    train_dir = os.path.join(log_dir, 'variables_train')
    test_dir = os.path.join(log_dir, 'variables')

    # latent vectors and physical variables (train and test)
    ids_train = np.load(os.path.join(train_dir, 'ids.npy'))
    ids_test = np.load(os.path.join(test_dir, 'ids.npy'))
    latent_file = 'refine_latent.npy' if if_refine else 'latent.npy'
    latent_train = np.load(os.path.join(train_dir, latent_file))
    latent_test = np.load(os.path.join(test_dir, latent_file))
    phys_train = physical_variables_from_data_ids(phys_all, ids_train)
    phys_test = physical_variables_from_data_ids(phys_all, ids_test)
    phys_vars_list = phys_train.keys()

    # remove samples whose physical variables contain NaN (rejected frames)
    latent_train = _drop_nan_samples(phys_train, latent_train, phys_vars_list)
    latent_test = _drop_nan_samples(phys_test, latent_test, phys_vars_list)
    print(f'Valid training samples: {latent_train.shape[0]}; test samples: {latent_test.shape[0]}.')

    # only use first few principal components (optional)
    num_components = None if str(sys.argv[2]) == 'NA' else int(sys.argv[2])
    latent_train, latent_test, pca_model = pca(latent_train, latent_test, num_components, cfg.seed)

    num_samples = 3
    sample_size = int(0.3 * latent_train.shape[0])
    train_error = {p_var: np.zeros(num_samples) for p_var in phys_vars_list}
    test_error = {p_var: np.zeros(num_samples) for p_var in phys_vars_list}
    reg_models = {p_var: np.empty(num_samples, dtype=object) for p_var in phys_vars_list}
    rng = np.random.default_rng(cfg.seed)
    for i in range(num_samples):
        print('random samples #' + str(i + 1))
        samp_ids = rng.choice(latent_train.shape[0], sample_size, replace=False)
        for p_var in phys_vars_list:
            print('doing ' + p_var + ' regression...')
            train_error[p_var][i], test_error[p_var][i], reg_models[p_var][i] = mlp_regress(
                latent_train[samp_ids], latent_test, phys_train[p_var][samp_ids], phys_test[p_var], cfg.seed)
    if num_components is None:
        np.save(os.path.join(test_dir, 'regression_results.npy'), [train_error, test_error, reg_models])
    else:
        np.save(os.path.join(test_dir, 'regression_results_pca_' + str(num_components) + '.npy'),
                [train_error, test_error, pca_model, reg_models])


if __name__ == '__main__':
    main()
================================================
FILE: analysis/intrinsic_dimension_estimation/__init__.py
================================================
import numpy as np
from .methods import *
class UnknownMethodError(ValueError, AssertionError):
    """Raised by ``ID_Estimator.set_method`` for an unsupported method name.

    Derives from ``ValueError`` (the conventional type for a bad argument) and
    also from ``AssertionError`` so that callers written against the previous
    ``assert False`` implementation still catch it.
    """


class ID_Estimator:
    """Front-end dispatching to the intrinsic-dimension estimators in ``.methods``.

    Attributes:
        all_methods: names accepted by ``set_method``.
        method: the currently selected method name.
    """

    def __init__(self, method='Levina_Bickel'):
        self.all_methods = ['Levina_Bickel', 'MiND_ML', 'MiND_KL', 'Hein', 'CD']
        self.set_method(method)

    def set_method(self, method='Levina_Bickel'):
        """Select the estimator used by ``fit``.

        Raises:
            UnknownMethodError: if ``method`` is not one of ``all_methods``.
                An explicit ``raise`` is used rather than ``assert`` because
                assertions are stripped under ``python -O``. With the assert
                stripped, an unknown name would be accepted silently and
                ``self.method`` would stay unset.
        """
        if method not in self.all_methods:
            raise UnknownMethodError(
                f'Unknown method! {method!r} is not one of {self.all_methods}.')
        self.method = method

    def fit(self, X, k_list=20, n_jobs=4):
        """Estimate the intrinsic dimension of ``X`` with the selected method.

        Args:
            X: data passed unchanged to the ``.methods`` helpers.
            k_list: neighbourhood size(s) for the kNN-based methods. It may be
                a scalar (including a NumPy scalar or 0-d array) or a sequence.
                It is ignored by 'Hein' and 'CD'.
            n_jobs: worker count forwarded to ``kNN``.

        Returns:
            For 'Hein' or 'CD', the corresponding value returned by ``Hein_CD``.
            Otherwise one estimate per entry of ``k_list``: the bare estimate
            when ``k_list`` has a single entry, else a NumPy array.
        """
        if self.method in ('Hein', 'CD'):
            dim_Hein, dim_CD = Hein_CD(X)
            return dim_Hein if self.method == 'Hein' else dim_CD
        # atleast_1d normalises scalars, 0-d arrays and sequences to a 1-d array.
        k_list = np.atleast_1d(k_list)
        # NOTE(review): the +2 presumably covers the query point itself and the
        # normalising neighbour (as in the MATLAB KNN helper) - confirm in .methods.
        kmax = np.max(k_list) + 2
        dists, _ = kNN(X, kmax, n_jobs)
        dims = []
        for k in k_list:
            if self.method == 'Levina_Bickel':
                dims.append(Levina_Bickel(X, dists, k))
            elif self.method == 'MiND_ML':
                dims.append(MiND_ML(X, dists, k))
            elif self.method == 'MiND_KL':
                dims.append(MiND_KL(X, dists, k))
        return dims[0] if len(dims) == 1 else np.array(dims)

    def fit_all_methods(self, X, k_list=(20,), n_jobs=4):
        """Run every method in ``all_methods`` on ``X``.

        Args:
            X: data passed unchanged to the ``.methods`` helpers.
            k_list: neighbourhood size(s) for the kNN-based methods. It may be
                a scalar or a sequence. The default is a tuple to avoid a
                shared mutable default.
            n_jobs: worker count forwarded to ``kNN``.

        Returns:
            dict mapping each method name to a NumPy array. The kNN-based
            methods get one estimate per ``k_list`` entry. 'Hein' and 'CD' get
            ``np.array`` of the corresponding ``Hein_CD`` output.
        """
        # atleast_1d lets a scalar k_list work here, as it does in ``fit``.
        k_list = np.atleast_1d(k_list)
        kmax = np.max(k_list) + 2
        dists, _ = kNN(X, kmax, n_jobs)
        dim_all_methods = {method: [] for method in self.all_methods}
        dim_all_methods['Hein'], dim_all_methods['CD'] = Hein_CD(X)
        for k in k_list:
            dim_all_methods['Levina_Bickel'].append(Levina_Bickel(X, dists, k))
            dim_all_methods['MiND_ML'].append(MiND_ML(X, dists, k))
            dim_all_methods['MiND_KL'].append(MiND_KL(X, dists, k))
        return {method: np.array(dims) for method, dims in dim_all_methods.items()}
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/DANCo.m
================================================
function [d,kl] = DANCo(data,varargin)
% DANCO Estimating the intrinsic dimensionality of a dataset
%
%   function [d,kl] = DANCo(data, ParamName,ParamValue, ...)
%
%   This function estimates the intrinsic dimensionality of a given dataset
%   by means of the DANCo algorithm described in:
%
%   "DANCo: Dimensionality from Angle and Norm Concentration"
%   C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli
%   arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1
%
%   The reference statistics are recomputed here on synthetic hyperspheres
%   (randsphere) for every candidate dimensionality. DANCoFit is the faster
%   variant that reads precomputed fitting functions from a model file.
%
%   Parameters
%   ----------
%   IN:
%     data = The data matrix (DxN) with a sample per column.
%     Named parameters:
%       'k'
%           The number of neighbours to be used, must be >= 1. (def=10)
%       'minDim'
%           The minimal number of dimensions. (def=1)
%       'maxDim'
%           The maximal number of dimensions. (def=size(data,1))
%       'sigma'
%           Standard deviation of the gaussian smoothing filter executed on
%           the estimated divergences. (def=[]=no smoothing)
%       'fractal'
%           Is the fractal dimension required? (def=false)
%       'mlMethod'
%           Algorithm used to pre-estimate d:
%             MLE      = Levina's algorithm.
%             MiND_MLI = ML on the first neighbour.
%             MiND_MLK = ML on the first neighbour and optimization.
%           (def='MiND_MLK')
%       'minCorrectionDim'
%           Minimal dimensionality estimated with the selected mlMethod such
%           that DANCo is adopted for correction; for lower dimensionalities
%           the mlMethod result is produced. (def=5)
%       'inds'
%           Precomputed knn indices as returned by the KNN function for k
%           set to k+2 with respect to the given k parameter. (dists must
%           also be present)
%       'dists'
%           Precomputed knn distances as returned by the KNN function for k
%           set to k+2 with respect to the given k parameter. (inds must
%           also be present)
%       'normalized'
%           Are the inds/dists already normalized? Normalized here means that
%           all the k+2 distances are normalized wrt the last one and the
%           first (zero) and the last (one) are removed. (def=true)
%   OUT:
%     d  = Estimated intrinsic dimensionality of the dataset.
%     kl = The Kullback-Leibler divergence curve, kl(i) being the value for
%          dimensionality minDim+i-1 (i.e. d = minDim..maxDim).
%
%   Examples
%   --------
%     X = randsphere(30,1000,1);
%     V = linSubspSpanOrthonormalize(randn(120,30));
%     pts = V*X;
%     [inds,dists] = KNN(pts,12,true);
%     [d,kl] = DANCo(pts,'inds',inds,'dists',dists,'minDim',20,'maxDim',40);

% Checking parameters:
if nargin<1; error('A dataset is required'); end
% Infos:
[D,N] = size(data);
% Default parameters:
params = struct('k',{10}, ...
    'minDim',{1}, 'maxDim',{D}, ...
    'fractal',{false}, ...
    'sigma',{[]},...
    'mlMethod',{'MiND_MLK'}, ...
    'normalized',{true}, ...
    'minCorrectionDim',{5});
% Parsing named parameters:
if nargin > 1
    try
        params = parseParamsNamed(varargin, false, params);
    catch e
        error('Pairs ''name,value'' are required as arguments after the dataset.');
    end
end
% Estimating the parameters for the real data:
[isLowDim,dHat,mu,tau] = DANCo_statistics(data,params.k,params,nargout > 1);
% The KL curve is needed for the correction and whenever it is requested:
needKL = not(isLowDim) || nargout > 1;
% Correction based on the KL-divergence:
if needKL
    % Initializing the parameters:
    params2 = struct( ...
        'mlMethod',{params.mlMethod}, ...
        'minCorrectionDim',{params.minCorrectionDim});
    numDims = params.maxDim - params.minDim + 1;
    % Computing the reference parameters on synthetic hyperspheres:
    dHat_ref = zeros(1,numDims);
    mu_ref = zeros(1,numDims);
    tau_ref = zeros(1,numDims);
    for i = 1:numDims
        % Generating the dataset:
        data2 = randsphere(params.minDim + i - 1, N);
        % Estimating the parameters:
        [~,dHat_ref(i),mu_ref(i),tau_ref(i)] = ...
            DANCo_statistics(data2,params.k,params2,true);
    end
    % Estimating the KL-divergence for all the cases d=minDim:maxDim:
    kl = DANCo_estimateKL(params.k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref);
end
% Smoothing if required:
if needKL && not(isempty(params.sigma))
    % Preparing the impulse response:
    rad = 4*ceil(params.sigma/2);
    h = gauss(-rad:rad,params.sigma,0,false);
    h = h/sum(h);
    % Filtering with mirror padding on BOTH sides. The right pad used to
    % repeat kl unflipped, which joined kl(end) to kl(1) at the boundary.
    kl = conv([fliplr(kl),kl,fliplr(kl)],h,'same');
    kl = kl(numDims+1:end-numDims);
end
% The dimensionality estimation:
if isLowDim
    % Using the mlMethod pre-estimate directly. dHat is already an absolute
    % dimensionality, so the minDim offset applied to kl indices below must
    % not be added to it.
    if params.fractal; d = dHat;
    else d = round(dHat); end
else
    % Selecting the minimum kl position:
    [~,d] = min(kl);
    % Refining for the fractal dimension:
    if params.fractal
        % Fitting with a cubic smoothing spline:
        fn = csapi(1:numDims,kl);
        % Locating the minima:
        opts = optimset('Display','off','TolX',1e-3);
        d = fminsearch(@(x)(fnval(fn,x)),d,opts);
    end
    % kl(i) corresponds to dimensionality minDim+i-1: mapping index -> dim.
    d = d + params.minDim - 1;
end
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/DANCoFit.m
================================================
function [d,kl] = DANCoFit(data,varargin)
% DANCOFIT Estimating the intrinsic dimensionality of a dataset
%
%   function [d,kl] = DANCoFit(data, ParamName,ParamValue, ...)
%
%   This function estimates the intrinsic dimensionality of a given dataset
%   by means of the fast DANCo algorithm described in:
%
%   "DANCo: Dimensionality from Angle and Norm Concentration"
%   C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli
%   arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1
%
%   with the complexity improvement given by separating the training step
%   from the evaluation one. The fitting models are read from a model file
%   (def 'DANCo_fits') containing the variables k, fitDhat, fitMu and fitTau,
%   as written by DANCoTrain.
%
%   Parameters
%   ----------
%   IN:
%     data = The data matrix (DxN) with a sample per column.
%     Named parameters:
%       'fractal'
%           Is the fractal dimension required? (def=false)
%       'mlMethod'
%           Algorithm used to pre-estimate d:
%             MLE      = Levina's algorithm.
%             MiND_MLI = ML on the first neighbour.
%             MiND_MLK = ML on the first neighbour and optimization.
%           (def='MiND_MLK')
%       'minCorrectionDim'
%           Minimal dimensionality estimated with the selected mlMethod such
%           that DANCo is adopted for correction; for lower dimensionalities
%           the mlMethod result is produced. (def=5)
%       'inds'
%           Precomputed knn indices as returned by the KNN function for k
%           set to k+2 with respect to the given k parameter. (dists must
%           also be present)
%       'dists'
%           Precomputed knn distances as returned by the KNN function for k
%           set to k+2 with respect to the given k parameter. (inds must
%           also be present)
%       'normalized'
%           Are the inds/dists already normalized? Normalized here means that
%           all the k+2 distances are normalized wrt the last one and the
%           first (zero) and the last (one) are removed. (def=true)
%       'modelfile'
%           The name of the model file to be used. (def='DANCo_fits')
%   OUT:
%     d  = Estimated intrinsic dimensionality of the dataset.
%     kl = The Kullback-Leibler divergence curve for d=1..D.
%
%   Examples
%   --------
%     X = randsphere(30,1000,1);
%     V = linSubspSpanOrthonormalize(randn(120,30));
%     pts = V*X;
%     [inds,dists] = KNN(pts,12,true);
%     [d,kl] = DANCoFit(pts,'inds',inds,'dists',dists);

% Checking parameters:
if nargin<1; error('A dataset is required'); end
% Infos:
[D,N] = size(data);
% Candidate dimensionalities at which the fitted reference curves are evaluated:
dims = 1:D;
% Default parameters:
params = struct('fractal',{false}, ...
    'mlMethod',{'MiND_MLK'}, ...
    'normalized',{true}, ...
    'minCorrectionDim',{5}, ...
    'modelfile',{'DANCo_fits'});
% Parsing named parameters:
if nargin > 1
    try
        params = parseParamsNamed(varargin, false, params);
    catch e
        error('Pairs ''name,value'' are required as arguments after the dataset.');
    end
end
% Loading the model into a struct. A bare load() would create variables in
% this workspace that could silently overwrite locals (data, D, N, params,
% ...) and would leave a missing model variable as a cryptic undefined-name error.
model = load(params.modelfile);
required = {'k','fitDhat','fitMu','fitTau'};
missing = required(not(isfield(model,required)));
if not(isempty(missing))
    error('Model file ''%s'' lacks the variable(s):%s', ...
        params.modelfile, sprintf(' %s',missing{:}));
end
k = model.k;
% Estimating the parameters:
[isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,nargout > 1);
% Correction based on the KL-divergence:
if not(isLowDim) || nargout > 1
    % Computing the reference parameters from the fitted models:
    dHat_ref = feval(model.fitDhat,dims,N);
    mu_ref = feval(model.fitMu,dims,N);
    tau_ref = feval(model.fitTau,dims,N);
    % Estimating the KL-divergence for all the cases d=1:D:
    kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref);
end
% The dimensionality estimation:
if isLowDim
    % Using the mlMethod pre-estimate directly:
    if params.fractal; d = dHat;
    else d = round(dHat); end
else
    % Selecting the minimum kl position (kl(i) is dimensionality i):
    [~,d] = min(kl);
    % Refining for the fractal dimension:
    if params.fractal
        % Fitting with a cubic smoothing spline:
        fn = csapi(1:D,kl);
        % Locating the minima:
        opts = optimset('Display','off','TolX',1e-3);
        d = fminsearch(@(x)(fnval(fn,x)),d,opts);
    end
end
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/DANCoTrain.m
================================================
function DANCoTrain(modelfile,varargin)
% DANCOTRAIN Build the fitting-function model file consumed by DANCoFit.
%
%   DANCoTrain(modelfile, ParamName,ParamValue, ...)
%
%   Runs DANCo_statistics on points drawn uniformly from unit-radius
%   hyperspheres (randsphere), for every dimensionality 2..maxDim and every
%   requested cardinality. The results are averaged over 'iterations' runs
%   and fitted with gridded smoothing splines (csaps + fnxtr). The model file
%   stores the handles fitDhat, fitMu and fitTau together with k. If you use
%   this, please cite:
%
%   "DANCo: Dimensionality from Angle and Norm Concentration"
%   C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli
%   arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1
%
%   IN:
%     modelfile = Name of the model file to write.
%     Named parameters:
%       'k'             Number of neighbours, >= 1. (def=10)
%       'maxDim'        Largest dimensionality simulated; larger values cost
%                       more time but improve the fits. (def=100)
%       'cardinalities' Sample sizes simulated; more values cost more time
%                       but improve the fits, a 1-2-5 series is suggested.
%                       (def=[50,100,200,500,1000,2000,5000])
%       'iterations'    Runs averaged per (dimensionality, cardinality)
%                       pair. (def=100)
%       'mlMethod'      Pre-estimator of d: 'MLE' (Levina), 'MiND_MLI' (ML
%                       on the first neighbour) or 'MiND_MLK' (ML on the
%                       first neighbour plus optimization). (def='MiND_MLK')
%
%   Example:
%     DANCoTrain('sampleModel','maxDim',20,'cardinalities',[200,500,1000],'iterations',1);
%     X = randsphere(10,500);
%     V = linSubspSpanOrthonormalize(randn(20,10));
%     pts = V*X;
%     DANCoFit(pts,'modelfile','sampleModel')

if nargin < 1
    error('A model file name is required.');
end
params = struct('k',{10}, ...
    'maxDim',{100}, ...
    'cardinalities',{[50,100,200,500,1000,2000,5000]}, ...
    'iterations',{100}, ...
    'mlMethod',{'MiND_MLK'});
if nargin > 1
    try
        params = parseParamsNamed(varargin, false, params);
    catch e
        error('Pairs ''name,value'' are required as arguments after the dataset.');
    end
end
% Options forwarded to DANCo_statistics (minCorrectionDim is 1):
statParams = struct('mlMethod',{params.mlMethod},'minCorrectionDim',{1});
k = params.k;
cards = params.cardinalities;
nCards = numel(cards);
maxDim = params.maxDim;
nIter = params.iterations;
% Row r of each table refers to dimensionality r+1, column c to cards(c):
dHat = zeros(maxDim-1,nCards);
mu = zeros(maxDim-1,nCards);
tau = zeros(maxDim-1,nCards);
for dim = 2:maxDim
    row = dim - 1;
    for c = 1:nCards
        nPts = cards(c);
        % One column per run; rows hold dHat, mu and tau respectively:
        runs = zeros(3,nIter);
        for it = 1:nIter
            sample = randsphere(dim, nPts);
            fprintf('(d = %d \tn = %d \ti = %d)...', dim,nPts,it);
            [~,runs(1,it),runs(2,it),runs(3,it)] = ...
                DANCo_statistics(sample,k,statParams,true);
            fprintf(' \tDONE!\n');
        end
        dHat(row,c) = mean(runs(1,:));
        mu(row,c) = mean(runs(2,:));
        tau(row,c) = mean(runs(3,:));
    end
end
fprintf('Fitting & saving...');
% Smoothing splines over the (dimensionality, cardinality) grid, made
% extrapolable (fnxtr with order 2) outside the simulated range:
grid = {2:maxDim, cards};
fitDhatFun = fnxtr(csaps(grid,dHat),2);
fitMuFun = fnxtr(csaps(grid,mu),2);
fitTauFun = fnxtr(csaps(grid,tau),2);
% Handles returning a row of values for the dimensionalities D at size N:
fitDhat = @(D,N)(fnval(fitDhatFun,{D,N})');
fitMu = @(D,N)(fnval(fitMuFun,{D,N})');
fitTau = @(D,N)(fnval(fitTauFun,{D,N})');
save(modelfile,'fitMu','fitTau','fitDhat','k');
fprintf(' \tDONE!\n');
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/KNN.m
================================================
function [inds,dists] = KNN(X,k,normalized)
%KNN k-nearest-neighbour search over the columns of X.
%
%   [inds,dists] = KNN(X)
%   [inds,dists] = KNN(X,k)
%   [inds,dists] = KNN(X,k,normalized)
%
%   IN:
%     X          = The dataset, one point per column.
%     k          = Number of neighbours of interest. (def=10)
%     normalized = If true, k+2 neighbours are searched. The first column
%                  (the query point itself when there are no duplicates) is
%                  dropped, inds keeps the next k neighbours, and dists keeps
%                  their distances divided by the distance to the (k+2)-th
%                  neighbour. (def=false)
%   OUT:
%     inds  = Row i holds the neighbour indexes of the i-th point.
%     dists = Row i holds the corresponding neighbour distances.

if nargin < 1
    error('Dataset required.');
end
if nargin < 2
    k = 10;
end
if nargin < 3
    normalized = false;
end
% The normalized variant needs the point itself and the normalizer too:
nSearch = k;
if normalized
    nSearch = k + 2;
end
if datenum(version('-date')) < 734729
    % Older releases: brute-force distance matrix, sorted row by row.
    [dists, inds] = sort(L2(X, X), 2);
    dists = dists(:,1:nSearch);
    inds = inds(:,1:nSearch);
else
    % knnsearch expects one point per row:
    pts = X';
    [inds, dists] = knnsearch(pts,pts,'K',nSearch);
end
if normalized
    inds = inds(:,2:end-1);
    dists = bsxfun(@rdivide, dists(:,2:end-1), dists(:,end));
end
end
% ------------------------ LOCAL FUNCTIONS ------------------------
% Pairwise Euclidean distances between the columns of a and the columns of
% b, expanded as |a|^2 + |b|^2 - 2*a'*b (by Roland Bunschoten). real() keeps
% only the real part when rounding drives an entry slightly below zero and
% sqrt returns a complex value.
function d = L2(a,b)
sqNormA = sum(a .* a)';
sqNormB = sum(b .* b);
twoAtB = 2 * a' * b;
d = real(sqrt(bsxfun(@plus, sqNormA, bsxfun(@minus, sqNormB, twoAtB))));
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/MLE.m
================================================
function [d,ds] = MLE(data,varargin)
%MLE Intrinsic dimensionality via Levina's maximum-likelihood estimator.
%
%   [d,ds] = MLE(data, ParamName,ParamValue, ...)
%
%   Computes one local estimate per point from its normalized neighbour
%   distances. The global estimate is the harmonic mean of the local ones.
%
%   IN:
%     data = DxN matrix, one point per column.
%     Named parameters:
%       'k'          Number of neighbours, >= 1. (def=10)
%       'dists'      Precomputed knn distances as returned by KNN for k set
%                    to k+2 w.r.t. the given k. They are used only when they
%                    have one row per point and more than two columns.
%       'normalized' Are the supplied dists already normalized, i.e. divided
%                    by the last of the k+2 distances, with the first (zero)
%                    and the last (one) removed? (def=true)
%   OUT:
%     d  = Estimated intrinsic dimensionality of the dataset.
%     ds = Local (per-point) estimates.
%
%   Example:
%     X = randsphere(30,1000,1);
%     V = linSubspSpanOrthonormalize(randn(120,30));
%     pts = V*X;
%     [inds,dists] = KNN(pts,12,true);
%     d = MLE(pts,'dists',dists);

if nargin < 1
    error('A dataset is required');
end
N = size(data,2);
params = struct('k',{10},'normalized',{true});
if nargin > 1
    try
        params = parseParamsNamed(varargin, false, params);
    catch e
        error('Pairs ''name,value'' are required as arguments after the dataset.');
    end
end
% Caller-supplied distances are reused only if their shape is plausible:
useGiven = isfield(params,'dists') && ...
    size(params.dists,1) == N && size(params.dists,2) > 2;
if useGiven
    dists = params.dists;
    if not(params.normalized)
        % Drop the self-distance, divide by the farthest neighbour and drop it:
        dists = bsxfun(@rdivide, dists(:,2:end-1), dists(:,end));
    end
else
    if not(isfield(params,'k'))
        error('One parameter of ''dists'' or ''k'' is required.');
    end
    [~,dists] = KNN(data,params.k,true);
end
% Per-point estimate: inverse of the mean negative log distance ratio.
ds = -1./mean(log(dists),2);
% Global estimate: harmonic mean of the per-point estimates.
d = 1./mean(1./ds);
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/MiND_KL.m
================================================
function [d,kl] = MiND_KL(data,varargin)
%MiND_KL Intrinsic dimensionality using the MiND_KL technique
%
%   function [d,kl] = MiND_KL(data, ParamName,ParamValue, ...)
%
%   This function estimates the intrinsic dimensionality of a dataset using
%   the Minimum Neighbor Distance estimators described in:
%
%   "Minimum Neighbor Distance Estimators of Intrinsic Dimension",
%   A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,
%   European Conference on Machine Learning (ECML 2011).
%
%   The distribution of the first normalized neighbour distance of the data
%   is compared (Wang's 1-d KL estimator) with the one obtained on uniformly
%   sampled unit hyperspheres of each candidate dimensionality. The
%   dimensionality minimizing the divergence is returned.
%
%   Parameters
%   ----------
%   IN:
%     data = Matrix DxN containing the dataset as columnwise points.
%     Named parameters:
%       'k'
%           The number of neighbours to be used, must be >= 1. (def=10)
%       'minDim'
%           The minimal number of dimensions. (def=1)
%       'maxDim'
%           The maximal number of dimensions. (def=size(data,1))
%       'sigma'
%           Standard deviation of the gaussian smoothing filter executed on
%           the estimated divergences. (def=[]=no smoothing)
%       'dists'
%           Precomputed knn distances as returned by the KNN function for k
%           set to k+2 with respect to the given k parameter (k is set to
%           "size(dists,2)-2" if knn is present).
%       'normalized'
%           Are the dists already normalized? Normalized here means that all
%           the k+2 distances are normalized wrt the last one and the first
%           (zero) and the last (one) are removed. (def=true)
%   OUT:
%     d  = Estimated intrinsic dimensionality of the dataset.
%     kl = The Kullback-Leibler divergence curve, kl(i) being the value for
%          dimensionality minDim+i-1 (i.e. d = minDim..maxDim).
%
%   Examples
%   --------
%     X = randsphere(30,1000,1);
%     V = linSubspSpanOrthonormalize(randn(120,30));
%     pts = V*X;
%     [inds,dists] = KNN(pts,12,true);
%     [d,kl] = MiND_KL(pts,'dists',dists);

% Checking parameters:
if nargin<1; error('A dataset is required'); end
% Infos:
[D,N] = size(data);
% Default parameters:
params = struct('k',{10}, ...
    'normalized',{true}, ...
    'minDim',{1}, 'maxDim',{D}, ...
    'sigma',[]);
% Parsing named parameters:
if nargin > 1
    try
        params = parseParamsNamed(varargin, false, params);
    catch e
        error('Pairs ''name,value'' are required as arguments after the dataset.');
    end
end
% Is the knn already there?
mustDoKnn = true;
if isfield(params,'dists')
    % Extracting and checking:
    dists = params.dists;
    if size(dists,1) == N && size(dists,2) > 2
        % KNN already present:
        params.k = size(dists,2)-2;
        mustDoKnn = false;
    end
end
if mustDoKnn
    if not(isfield(params,'k'))
        error('One parameter of ''dists'' or ''k'' is required.');
    end
    % Normalized search, as in MLE and MiND_ML (identical to searching k+2
    % neighbours and normalizing by hand):
    [~,dists] = KNN(data,params.k,true);
elseif not(params.normalized)
    % Cropping the self-distance:
    dists = dists(:,2:end);
    % Normalizing and removing the last element:
    dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]);
end
% Only the first normalized neighbour is required:
k = size(dists,2);
dists = dists(:,1)';
% Executing the experiments with unit uniformly-sampled hyper-spheres:
numDims = params.maxDim - params.minDim + 1;
kl = zeros(1,numDims);
for i=1:numDims
    % Generating the data:
    data2 = randsphere(params.minDim + i - 1, N);
    % Computing the normalized knn distances:
    [~,distsExp] = KNN(data2,k,true);
    distsExp = distsExp(:,1)';
    % Estimating the Kullback-Leibler divergence:
    kl(i) = wangKL1d(distsExp,dists);
end
% Smoothing if required:
if not(isempty(params.sigma))
    % Preparing the impulse response:
    rad = 4*ceil(params.sigma/2);
    h = gauss(-rad:rad,params.sigma,0,false);
    h = h/sum(h);
    % Filtering with mirror padding on BOTH sides. The right pad used to
    % repeat kl unflipped, which joined kl(end) to kl(1) at the boundary.
    kl = conv([fliplr(kl),kl,fliplr(kl)],h,'same');
    kl = kl(numDims+1:end-numDims);
end
% Estimating the id (kl(i) corresponds to dimensionality minDim+i-1):
[~,d] = min(kl);
d = d + params.minDim - 1;
end
% ------------------------ LOCAL FUNCTIONS ------------------------
% Wang's one-dimensional nearest-neighbour estimator of the Kullback-Leibler
% divergence between the samples in data1 (n values) and data2 (m values).
% Both are expected as row vectors. The absolute value of the estimate is
% returned.
function div = wangKL1d(data1,data2)
n = numel(data1);
m = numel(data2);
sorted1 = sort(data1);
% Gaps between consecutive sorted samples, with infinite gaps at both ends:
gaps = abs(diff([-inf,sorted1,inf]));
% Distance of each data1 sample to its nearest other data1 sample:
rho = min(gaps(1:end-1),gaps(2:end));
% Distance of each data1 sample to its nearest data2 sample (m x n table):
nu = min(abs(bsxfun(@minus, sorted1, data2')),[],1);
% eps guards against zero within-sample distances (duplicated values):
div = abs(sum(log(nu./(rho+eps)))/n + log(m/(n-1)));
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/MiND_ML.m
================================================
function d = MiND_ML(data,varargin)
%MiND_ML Intrinsic dimensionality using the MiND_ML techniques
%
%   d = MiND_ML(data, ParamName,ParamValue, ...)
%
%   Maximum-likelihood estimate of the intrinsic dimensionality based on the
%   first normalized neighbour distance of each point, as described in:
%
%   "Minimum Neighbor Distance Estimators of Intrinsic Dimension",
%   A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,
%   European Conference on Machine Learning (ECML 2011).
%
%   The squared derivative of the log-likelihood is evaluated at d = 1..D
%   (MiND_MLI picks the minimizer). With 'optimize' it is then refined with
%   fminsearch to a fractional value (MiND_MLK).
%
%   IN:
%     data = DxN matrix, one point per column.
%     Named parameters:
%       'k'          Number of neighbours, >= 1. (def=10)
%       'optimize'   Refine to a fractional id (MiND_MLK instead of
%                    MiND_MLI)? (def=false)
%       'dists'      Precomputed knn distances as returned by KNN for k set
%                    to k+2 w.r.t. the given k. They are used only when they
%                    have one row per point and more than two columns.
%       'normalized' Are the supplied dists already normalized, i.e. divided
%                    by the last of the k+2 distances, with the first (zero)
%                    and the last (one) removed? (def=true)
%   OUT:
%     d = Estimated intrinsic dimensionality of the dataset.
%
%   Example:
%     X = randsphere(30,1000,1);
%     V = linSubspSpanOrthonormalize(randn(120,30));
%     pts = V*X;
%     [inds,dists] = KNN(pts,12,true);
%     d = MiND_ML(pts,'dists',dists);

if nargin < 1
    error('A dataset is required');
end
[D,N] = size(data);
params = struct('k',{10},'optimize',{false},'normalized',{true});
if nargin > 1
    try
        params = parseParamsNamed(varargin, false, params);
    catch e
        error('Pairs ''name,value'' are required as arguments after the dataset.');
    end
end
% Caller-supplied distances are reused only if their shape is plausible:
useGiven = isfield(params,'dists') && ...
    size(params.dists,1) == N && size(params.dists,2) > 2;
if useGiven
    dists = params.dists;
    if not(params.normalized)
        % Drop the self-distance, divide by the farthest neighbour and drop it:
        dists = bsxfun(@rdivide, dists(:,2:end-1), dists(:,end));
    end
else
    if not(isfield(params,'k'))
        error('One parameter of ''dists'' or ''k'' is required.');
    end
    [~,dists] = KNN(data,params.k,true);
end
[n,k] = size(dists);
% First normalized neighbour distance of every point, and its logarithm:
r = dists(:,1);
logR = log(r);
% Squared derivative of the log-likelihood with respect to d:
fun = @(d)((n/d)+sum(logR-((k-1).*(logR.*r.^d))./(1-r.^d)))^2;
% Integer candidates 1..D (MiND_MLI):
vals = arrayfun(fun, 1:D);
[~,d] = min(vals);
% Optional continuous refinement (MiND_MLK):
if params.optimize
    opts = optimset('Display','off','TolX',1e-3);
    d = fminsearch(fun,d,opts);
end
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/demo_idEstimation.m
================================================
%% Intrinsic Dimensionality (id) estimation
%
% In the past decade the development of automatic techniques to estimate
% the intrinsic dimensionality of a given dataset has gained considerable
% attention due to its relevance in several application fields.
%
% This small toolbox implements some of the state-of-the-art techniques as
% functions sharing a "standard" interface, so that client code can be
% decoupled from the choice of technique.
%
% For more details see/cite:
%
% "Minimum Neighbor Distance Estimators of Intrinsic Dimension",
% A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,
% Published to the European Conference of Machine Learning (ECML 2011).
%
% "DANCo: Dimensionality from Angle and Norm Concentration",
% C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli
% arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1
%% Usage of this id-estimation toolbox
%
% As an example, we generate a simple dataset and run all the implemented
% techniques on it to compare the results. The dataset is a
% hypersphere in R^{10} linearly embedded in R^{20}.
% Generating the dataset (500 points on a d-sphere, mapped into R^D through
% a random orthonormal basis):
d = 10; D = 20;
X = randsphere(d,500);
V = linSubspSpanOrthonormalize(randn(D,d));
data = V*X;
% A first base estimation:
fprintf('MLE => %2.2f\n', MLE(data));
fprintf('DANCoFit => %2.2f\n', DANCoFit(data));
%%
% The techniques (names of the estimator functions, called by handle):
estimators = {'MLE','MiND_ML','MiND_KL','DANCo','DANCoFit'};
% Initializing the results:
idEst = zeros(1,numel(estimators));
spentTime = zeros(1,numel(estimators));
% Running the experiments:
for i = 1:numel(estimators)
% Obtaining the function:
estimator = str2func(estimators{i});
% Estimating (and timing each estimator):
tic;
idEst(i) = estimator(data);
spentTime(i) = toc;
end
% Plotting the results (the dotted line marks the true id d):
figure; bar([idEst(:),spentTime(:)]); hold on;
plot([0,numel(estimators)+1],[d,d],'b:');
set(gca,'xticklabel',estimators);
legend({'Estimated id','Spent time'},'Location','NorthWest');
title(sprintf('Estimating the id of a %dd hypersphere linearly embedded in R^{%d}',d,D));
%% Estimators based on the Kullback-Leibler divergence
%
% Some of the presented algorithms estimate the Kullback-Leibler divergence
% between the distribution of some statistics computed on the real dataset
% and those computed on synthetically-generated data (hyperspheres) or by
% means of fitting functions (trained using the statistics produced by
% synthetic data). Here are the results obtained by some of them.
% Initialization:
estimators = {'MiND_KL','DANCo','DANCoFit'};
% Initializing the results:
idEst = zeros(1,numel(estimators));
spentTime = zeros(1,numel(estimators));
% One KL curve (over d = 1..D) per estimator:
kls = zeros(numel(estimators),D);
% Running the experiments:
for i = 1:numel(estimators)
% Obtaining the function:
estimator = str2func(estimators{i});
% Estimating (second output is the KL curve):
tic;
[idEst(i),kls(i,:)] = estimator(data);
spentTime(i) = toc;
end
% Notice that DANCo produces unreliable results for d=1, so it is skipped:
Ds = 2:D;
kls2 = kls(:,Ds);
% Plotting the results (vertical dotted lines mark each estimated id):
figure; plot(Ds,kls2'); hold on;
plot([idEst;idEst],[zeros(size(idEst));max(kls2(:))*ones(size(idEst))],':');
legend(estimators,'Location','NorthEast');
title('KL-based id estimatros: the estimated KLs');
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/gauss.m
================================================
function func = gauss(x, sigma, mu, norm)
% GAUSS Evaluate a one-dimensional gaussian on an array of points.
%
%   func = gauss(x, sigma, mu, norm)
%
%   IN (all optional):
%     x     = Points at which the gaussian is evaluated. (def=-10:10)
%     sigma = Standard deviation. (def=2.5)
%     mu    = Mean. (def=0)
%     norm  = If true, divide by sigma*sqrt(2*pi) so the curve integrates to
%             one; otherwise the peak value is 1. (def=true)
%   OUT:
%     func  = Values of the gaussian, same size as x.

% Fill in defaults for the trailing arguments that were not supplied:
if nargin < 4, norm = true; end
if nargin < 3, mu = 0; end
if nargin < 2, sigma = 2.5; end
if nargin < 1, x = -10:10; end
% Unnormalized bell curve, scaled afterwards when a density is requested:
func = exp(-(x-mu).^2/(2*sigma^2));
if norm
    func = func/(sigma*sqrt(2*pi));
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/html/demo_idEstimation.html
================================================
Intrinsic Dimensionality (id) estimation Intrinsic Dimensionality (id) estimation In the past decade the development of automatic techniques to estimate the intrinsic dimensionality of a given dataset has gained considerable attention due to its relevance in several application fields.
In this small toolbox some of the state-of-art techniques are implemented as functions that use a "standard" interface to use them, thus allowing client code to be decoupled from the knowledge of the applied technique.
For more details see/cite:
"Minimum Neighbor Distance Estimators of Intrinsic Dimension",
A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,
Published to the European Conference of Machine Learning (ECML 2011). "DANCo: Dimensionality from Angle and Norm Concentration",
C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli
arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1 Contents Usage of this id-estimation toolbox As an example here we generate a simple dataset and test all the implemented techniques on it to compare the results, the dataset is an hypersphere in R^{10} linearly embedded in R^{20}.
d = 10; D = 20;
X = randsphere(d,500);
V = linSubspSpanOrthonormalize(randn(D,d));
data = V*X;
fprintf('MLE => %2.2f\n' , MLE(data));
fprintf('DANCoFit => %2.2f\n' , DANCoFit(data));
MLE => 8.07
DANCoFit => 10.00
estimators = {'MLE' ,'MiND_ML' ,'MiND_KL' ,'DANCo' ,'DANCoFit' };
idEst = zeros(1,numel(estimators));
spentTime = zeros(1,numel(estimators));
for i = 1:numel(estimators)
estimator = str2func(estimators{i});
tic;
idEst(i) = estimator(data);
spentTime(i) = toc;
end
figure; bar([idEst(:),spentTime(:)]); hold on ;
plot([0,numel(estimators)+1],[d,d],'b:' );
set(gca,'xticklabel' ,estimators);
legend({'Estimated id' ,'Spent time' },'Location' ,'NorthWest' );
title(sprintf('Estimating the id of a %dd hypersphere linearly embedded in R^{%d}' ,d,D));
Estimators based on the Kullback-Leibler divergence Some of the presented algorithms estimate the Kullback-Leibler divergence between the distribution of some statistics computed on the real dataset and those computed on synthetically-generated data (hyperspheres) or by means of fitting functions (trained using the statistics produced by synthetic data). Here are the results obtained by some of them.
estimators = {'MiND_KL' ,'DANCo' ,'DANCoFit' };
idEst = zeros(1,numel(estimators));
spentTime = zeros(1,numel(estimators));
kls = zeros(numel(estimators),D);
for i = 1:numel(estimators)
estimator = str2func(estimators{i});
tic;
[idEst(i),kls(i,:)] = estimator(data);
spentTime(i) = toc;
end
Ds = 2:D;
kls2 = kls(:,Ds);
figure; plot(Ds,kls2'); hold on ;
plot([idEst;idEst],[zeros(size(idEst));max(kls2(:))*ones(size(idEst))],':' );
legend(estimators,'Location' ,'NorthEast' );
title('KL-based id estimatros: the estimated KLs' );
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/linSubspSpanOrthonormalize.m
================================================
function Vo = linSubspSpanOrthonormalize(V)
% LINSUBSPSPANORTHONORMALIZE Orthonormalizes a linear subspace span matrix.
%
% function Vo = linSubspSpanOrthonormalize(V)
%
% This function takes a DxN matrix that represents a set of N linearly
% independent vectors of a D-dimensional space, and that identifies an
% N-dimensional linear subspace, and orthonormalizes them so that the
% subspace spanned is the same but the vectors are unit vectors that are
% mutually orthogonal. The solutions to this problem are infinitely many;
% the one given is the Gram-Schmidt one: it normalizes the first vector,
% orthonormalizes the second with respect to the plane containing the first
% two vectors, orthonormalizes the third with respect to the 3D space
% spanned by the first three vectors and so on. Obviously must be 1<=N<=D.
%
% The modified Gram-Schmidt scheme is used: each projection is removed
% from the already partially orthogonalized vector rather than from the
% original one. In exact arithmetic this gives the same result as the
% classical scheme, but in floating point it keeps the columns much closer
% to orthogonal.
%
% Parameters
% ----------
% IN:
% V = The DxN matrix containing the N vectors.
% OUT:
% Vo = The orthonormalized linear subspace base.
%
% Pre
% ---
% - The columns of V must be linearly independent (i.e. rank(V')==N).
% - The number of columns of V must be >=1 and <=D.
%
% Post
% ----
% - The returned Vo contains columns that are normal
% (i.e. norm(Vo(:,i))==1, up to rounding).
% - The returned Vo contains columns that are orthogonal
% (i.e. dot(Vo(:,i),Vo(:,j))==0 for i~=j, up to rounding).
% Check params:
if size(V,2)<1 || size(V,2)>size(V,1)
    error('N columns must be >=1 and <=D rows!');
end
% Check for linear dependence:
if rank(V')~=size(V,2)
    error('The V vectors must be linearly independent!');
end
% Normalizing the V vectors (zero norms replaced by eps to avoid 0/0):
normV = sqrt(sum(V.^2));
normV(normV==0) = eps;
V = V./repmat(normV,[size(V,1),1]);
% Orthogonalizing with modified Gram-Schmidt:
for i=2:size(V,2)
    for j=1:i-1
        % Remove the component along the j-th (already orthonormal) versor,
        % measured on the partially updated V(:,i):
        V(:,i) = V(:,i) - V(:,j)*dot(V(:,j),V(:,i));
    end
    % Normalizing:
    V(:,i) = V(:,i)./norm(V(:,i));
end
% Returning:
Vo = V;
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/parseParamsNamed.m
================================================
function pars = parseParamsNamed(parsCell,recurse,parsi,multi)
% PARSEPARAMSNAMED Collects named parameters ('name',value pairs) into a struct.
%
% function pars = parseParamsNamed(parsCell,recurse,parsi,multi)
%
% Walks a cell laid out as {name1,value1,name2,value2,...} (typically
% varargin) and stores each value in the struct field named after it.
% When a name appears more than once the last value wins, unless multi is
% true, in which case all values for that name are collected in a cell.
%
% Parameters
% ----------
% IN:
% parsCell = Cell of alternating names (char) and values (like varargin).
% recurse = If true, values that are cells are parsed recursively into
%   sub-structs; the nested call uses its own defaults for parsi and
%   multi. (def=false)
% parsi = Struct used as the starting point, e.g. holding default values.
%   (def=struct)
% multi = If true, every field receives a cell with all the values given
%   for that name; a pre-existing non-cell value is wrapped in a cell
%   first. (def=false)
% OUT:
% pars = The structure containing the parameters.
% Validate the input cell and fill in the optional arguments:
if nargin<1 || ~isa(parsCell,'cell')
    error('A cell with parameters must be passed as parameter!');
end
if nargin<4
    multi = false;
end
if nargin<3
    parsi = struct;
end
if nargin<2
    recurse = false;
end
% Names and values must come in pairs:
nEntries = numel(parsCell);
if mod(nEntries,2)~=0
    error('The cell must contain an even number of elements!');
end
pars = parsi;
% idx points at the value of each pair; idx-1 at its name:
for idx = 2:2:nEntries
    key = parsCell{idx-1};
    if ~ischar(key)
        error('The name of parameters must be a string!');
    end
    val = parsCell{idx};
    if recurse && iscell(val)
        % Nested named parameters become a sub-struct:
        val = parseParamsNamed(val,recurse);
    end
    if multi
        % Make sure the field is a cell, then append the new value:
        if ~isfield(pars,key)
            pars.(key) = {};
        elseif ~iscell(pars.(key))
            pars.(key) = {pars.(key)};
        end
        pars.(key){end+1} = val;
    else
        % Single-valued mode: later occurrences overwrite earlier ones.
        pars.(key) = val;
    end
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/private/DANCo_estimateKL.m
================================================
function kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref)
%DANCO_ESTIMATEKL Estimating the Kullback-Leibler divergence for DANCo
%
% function kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref)
%
% This function estimates the Kullback-Leibler divergence between the
% distances/angles bivariate distributions over two datasets whose
% statistic parameters are {dHat,mu,tau} and {dHat_ref,mu_ref,tau_ref}
% respectively. The result is the sum of two terms: a closed-form KL
% between the normalized-distance distributions (EstimateKL_dists) and a
% KL between the von Mises angle distributions (EstimateKL_angles).
%
% Parameters
% ----------
% IN:
% k = The KNN parameter.
% dHat,mu,tau = Statistics for the real dataset.
% dHat_ref,mu_ref,tau_ref = Statistics for the synthetic dataset.
%   NOTE(review): presumably row vectors with one entry per candidate
%   dimension, yielding one KL value per entry -- confirm against DANCo.m.
% OUT:
% kl = The estimated KL divergence(s).
% Checking parameters:
if nargin<7; error('Too few parameters'); end
% Doing the computation (distance term + angle term):
kl = EstimateKL_dists(dHat,dHat_ref,k) + ...
EstimateKL_angles(mu,tau,mu_ref,tau_ref);
end
% ------------------------ LOCAL FUNCTIONS ------------------------
% Estimating the KL-divergence for all the cases d=1:D:
function kl = EstimateKL_angles(m1,k1,m2,k2)
% KL divergence between the von Mises distributions VM(m1,k1) and
% VM(m2,k2) (m = mean direction, k = concentration):
%   KL = log(I0(k2)/I0(k1)) + A(k1)*(k1 - k2*cos(m2-m1)),
% where A(k) = I1(k)/I0(k). Since I1 is odd, (I1(k1)-I1(-k1))/2 = I1(k1),
% which is how A(k1) is formed below. m2,k2 may be arrays (elementwise);
% m1,k1 are presumably scalars. The absolute value of the result is
% returned.
% The bessel values:
bk2 = besseli(0,k2);
bk1 = besseli(0,k1);
ak1 = besseli(1,k1);
ck1 = besseli(1,-k1);
% Checking for infs: large concentrations overflow besseli, so infinite
% values are clamped to realmax to keep the ratios from becoming Inf/Inf.
% NOTE(review): ck1 = besseli(1,-k1) is negative, so an overflow gives
% -Inf, which is clamped here to +realmax; the term (ak1-ck1) then becomes
% 0 instead of a large positive value. Confirm whether this is intended.
if(isinf(k1)); k1 = realmax; end
k2(isinf(k2)) = realmax;
bk2(isinf(bk2)) = realmax;
if(isinf(bk1)); bk1 = realmax; end
if(isinf(ak1)); ak1 = realmax; end
if(isinf(ck1)); ck1 = realmax; end
% Computing the KL:
kl = abs(log(bk2./bk1) + (ak1-ck1)./(2*bk1).*(k1-k2.*cos(m2-m1)));
end
% -----------------------------------------------------------------
% Estimating the KL-divergence for the distances:
function kl = EstimateKL_dists(d1,d2,k)
% Closed-form KL divergence between the distributions of the normalized
% nearest-neighbour distances for intrinsic dimension d1 (real data) and
% each entry of d2 (reference). Uses the harmonic numbers H_{k-1}, H_k and
% the alternating sum over i=0..k of (-1)^i*C(k,i)*psi(1+i*d1/d2).
% NOTE(review): the result has the shape of d2 only when d2 is a row
% vector (sum(bin,2)' is a row); a column d2 would broadcast into a
% matrix. Presumably d1 is a scalar -- confirm against callers.
% Initializing:
bin = zeros(numel(d2),k+1);
% Computing the terms of the summation (column i holds term i-1):
for i=1:k+1
bin(:,i) = (-1)^(i-1)*nchoosek(k,i-1).*psi(1+(i-1)*(d1./d2));
end
% Estimating the KL in closed form:
kl = -(1+harmonic(k-1)-harmonic(k)*(d2./d1)+log(d2./d1)+(k-1)*sum(bin,2)');
end
% -----------------------------------------------------------------
% Harmonic numbers (by David Terr):
function h = harmonic(z)
% HARMONIC Harmonic number H_z = 1 + 1/2 + ... + 1/z.
%
% H_0 = 0 (empty sum) and H_1 = 1 are returned exactly. The H_0 case is
% needed for k=1, where EstimateKL_dists calls harmonic(k-1); the
% asymptotic formula would give log(0) + 1/0 = NaN. Larger arguments use
% the asymptotic expansion
%   H_z ~ log(z) + gamma + 1/(2z) - 1/(12z^2) + 1/(120z^4) - 1/(252z^6),
% where gamma = 0.5772156649... is the Euler-Mascheroni constant.
if z == 0
    h = 0;
elseif z == 1
    h = 1;
else
    h = log(z) + 0.5772156649 + 1/(2*z) - 1/(12*z^2) + 1/(120*z^4) - 1/(252*z^6);
end
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/private/DANCo_statistics.m
================================================
function [isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,force)
%DANCO_STATISTICS Computes the statistics used by DANCo
%
% function [isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,force)
%
% This function computes the statistic parameters used by DANCo to
% estimate the intrinsic dimensionality. When isLowDim==true (and force is
% false) only dHat is computed, and mu and tau are set to 0.
%
% Parameters
% ----------
% IN:
% data = The dataset as a matrix with columnwise vectors.
% k = KNN parameter.
% params = The named parameters passed to DANCo (see DANCo or DANCoFit).
%   Required fields: mlMethod ('MLE', 'MiND_MLI' or 'MiND_MLK') and
%   minCorrectionDim. Optional fields: inds and dists (a precomputed KNN,
%   reused only when it has N rows and at least k+2 columns) and
%   normalized.
% force = Forces the computation of mu and tau. (def=false).
% OUT:
% isLowDim = True if dHat < params.minCorrectionDim; in this case the dHat
%   estimation is considered good enough.
% dHat = Estimated id by means of biased estimators like MLE or MiND.
% mu,tau = Von Mises circular statistic parameters: mu is the circular
%   mean of the per-point mean directions, tau the mean of the per-point
%   concentrations.
% Checking parameters:
if nargin<3; error('A dataset, k, and the parameters are required'); end
if nargin<4; force=false; end
if not(isa(params,'struct')); error('Params must be a struct'); end
if not(isfield(params,'mlMethod')); error('An mlMethod is required'); end
if not(isfield(params,'minCorrectionDim')); error('A minCorrectionDim is required'); end
% Infos (number of points, data being columnwise):
N = size(data,2);
% Is the knn already there?
mustDoKnn = true;
if isfield(params,'inds') && isfield(params,'dists')
% Extracting and checking:
inds = params.inds;
dists = params.dists;
% Reuse only with one row per point and k+2 columns: the first column is
% dropped and the last is used as the normalizer in the step below.
if size(dists,1) == N && size(dists,2) >= k+2 && size(inds,1) == N && size(inds,2) >= k+2
% KNN already present:
dists = dists(:,1:k+2);
inds = inds(:,1:k+2);
mustDoKnn = false;
end
end
% Finding the KNNs only if needed:
if mustDoKnn
% It's needed!
% NOTE(review): KNN(data,k,true) presumably returns already-normalized
% distances, since the normalization below is skipped in this case.
[inds,dists] = KNN(data,k,true);
end
% Normalizing the distances (only for a caller-supplied KNN that is not
% already normalized):
if not(mustDoKnn || (isfield(params,'normalized') && params.normalized))
% Cropping both dists and inds (the first column presumably holds each
% point itself):
dists = dists(:,2:end);
inds = inds(:,2:end);
% Normalizing by the farthest neighbour and removing that last element,
% leaving k ratios per point (inds keeps k+1 columns):
dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]);
end
% Estimating dHat:
switch params.mlMethod
case 'MLE'
dHat = MLE(data,'dists',dists,'normalized',true);
case 'MiND_MLI'
dHat = MiND_ML(data,'dists',dists,'optimize',false,'normalized',true);
case 'MiND_MLK'
dHat = MiND_ML(data,'dists',dists,'optimize',true,'normalized',true);
otherwise
error(['Unknown mlMethod: ' params.mlMethod]);
end
% Correction based on the KL-divergence:
isLowDim = dHat < params.minCorrectionDim;
if not(isLowDim) || force
% The circular statistics (one mu/tau pair per point):
angles = ComputeAngles(data,inds);
[mu,tau] = CircStats(angles);
% Averaging the parameters (circular mean for the directions):
mu = angle(sum(exp(1i*mu)));
tau = mean(tau);
else
% Fake parameters:
mu = 0;
tau = 0;
end
end
% ------------------------ LOCAL FUNCTIONS ------------------------
% Computing the used angular statistics:
function [mu,tau] = CircStats(angles)
% CIRCSTATS Row-wise von Mises parameter estimates.
%   angles = Matrix with one row per point and one column per angle sample.
%   mu     = Circular mean direction of each row (column vector).
%   tau    = Concentration estimate of each row (column vector), obtained
%            from the mean resultant length r with a piecewise
%            approximation of the inverse of A(kappa) = I1(kappa)/I0(kappa),
%            then corrected for small sample sizes.
nA = size(angles,2);
% Computing the complex circular means:
cm = sum(exp(1i*angles),2);
% Computing the von Mises parameters:
mu = angle(cm);
r = abs(cm)/nA;
% The piecewise approximation is applied element by element. Writing it
% as "if r < 0.53 ... " on the vector r would test ALL rows at once
% (a MATLAB if on an array is true only if every element is nonzero),
% so every row would get the same branch.
tau = zeros(size(r));
low = r < 0.53;
high = r >= 0.85;
mid = ~low & ~high;
tau(low) = 2*r(low) + r(low).^3 + 5*r(low).^5./6;
tau(mid) = -0.4 + 1.39*r(mid) + 0.43./(1-r(mid));
tau(high) = 1./(r(high).^3 - 4*r(high).^2 + 3*r(high));
% Correcting for small sample sizes:
if nA < 15
    mask = tau < 2;
    tau(mask) = max(tau(mask)-2./(nA.*tau(mask)),0);
    tau(~mask) = (nA-1).^3.*tau(~mask)./(nA.^3+nA);
end
end
% -----------------------------------------------------------------
% Computing the pairwase angles for neighbours:
function angles = ComputeAngles(data,inds)
% For every point i (column of data), computes the angles at i between
% all pairs of its neighbours data(:,inds(i,:)). Row i of the result holds
% the K*(K-1)/2 pairwise angles (in [0,pi]), K being size(inds,2).
% NOTE(review): if a neighbour coincides with point i, its normalized
% vector is 0/0 = NaN; presumably inds excludes the point itself -- confirm.
% Init:
K = size(inds,2);
angles = zeros(size(inds,1),K*(K-1)/2);
cnk = combnk(1:size(inds,2),2);
% Iterating on points:
for i=1:size(data,2)
% Translating and normalizing points (unit vectors from point i to each
% of its neighbours):
pts = data(:,inds(i,:)) - repmat(data(:,i),[1,K]);
pts = pts./repmat(sqrt(sum(pts.^2,1)),[size(pts,1),1]);
% Computing all the pairwise angles as atan2(|b - (a.b)a|, a.b), which
% is the angle between the unit vectors a and b:
a = pts(:,cnk(:,1)); b = pts(:,cnk(:,2));
angles(i,:) = atan2(sqrt(sum((b-repmat(sum(a.*b,1),[size(a,1),1]).*a).^2,1)),sum(a.*b,1));
end
end
================================================
FILE: analysis/intrinsic_dimension_estimation/matlab_codes/randsphere.m
================================================
function X = randsphere(d,N,r)
% RANDSPHERE Random points uniformly drawn from a sphere with radius r
%
% function X = randsphere(d,N,r)
%
% Draws N samples uniformly distributed inside the d-dimensional ball of
% radius r centred at the origin (the method posted by Roger Stafford on
% matlabcentral/fileexchange). Each Gaussian sample keeps its direction,
% and its norm is remapped through the chi-square CDF (gammainc) so that
% the radius follows the r*U^(1/d) law of a uniform ball.
%
% Parameters
% ----------
% IN:
% d = Space dimensionality. (def=2)
% N = Number of points. (def=1000)
% r = Hypersphere radius. (def=1)
% OUT:
% X = dxN matrix of points (columnwise).
% Defaults for the omitted arguments:
if nargin<3; r=1; end
if nargin<2; N=1000; end
if nargin<1; d=2; end
% Standard normal samples: their directions are uniform on the sphere.
G = randn(d,N);
% Squared norm of every column:
sqNorm = sum(G.^2,1);
% Per-column factor that maps each norm to its new radius:
factor = r*(gammainc(sqNorm/2,d/2).^(1/d))./sqrt(sqNorm);
X = bsxfun(@times,G,factor);
end
================================================
FILE: analysis/intrinsic_dimension_estimation/methods.py
================================================
import numpy as np
import os
from sklearn.neighbors import NearestNeighbors
def kNN(X, n_neighbors, n_jobs):
    """Find the ``n_neighbors`` nearest neighbours of every row of ``X``.

    The index is built on ``X`` and then queried with ``X`` itself, so each
    point normally appears among its own neighbours (column 0, distance 0,
    unless there are duplicates).

    Args:
        X: data matrix, one sample per row.
        n_neighbors: number of neighbours returned per sample.
        n_jobs: parallel jobs forwarded to ``NearestNeighbors``.

    Returns:
        ``(dists, inds)``: neighbour distances and indices, sorted by
        increasing distance.
    """
    index = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=n_jobs)
    index.fit(X)
    neighbour_dists, neighbour_inds = index.kneighbors(X)
    return neighbour_dists, neighbour_inds
def Levina_Bickel(X, dists, k):
    """Levina-Bickel maximum-likelihood estimate of the intrinsic dimension.

    For each point, ``(k - 2) / sum_j log(T_k / T_j)`` is computed over
    ``j = 1 .. k-1``, where ``T_j`` is ``dists[:, j]`` (column 0 is skipped,
    presumably the point itself as returned by ``kNN``). The (k - 2)
    normalization is the bias-corrected variant of the MLE. The per-point
    estimates are then averaged.

    Args:
        X: not used; presumably kept so that all estimators share a signature.
        dists: sorted neighbour-distance matrix with at least ``k + 1`` columns.
        k: neighbourhood size.

    Returns:
        The mean of the per-point dimension estimates.
    """
    farthest = dists[:, k:k+1]
    nearer = dists[:, 1:k]
    log_ratio_sums = np.sum(np.log(farthest / nearer), axis=1)
    local_dims = (k - 2) / log_ratio_sums
    return np.mean(local_dims)
def start_matlab_engine():
    """Launch a MATLAB engine with the bundled ``matlab_codes`` folder on its path.

    ``matlab.engine`` is imported inside the function, so importing this
    module does not require MATLAB to be installed.

    Returns:
        The running MATLAB engine session.
    """
    import matlab.engine
    codes_dir = os.path.join(os.path.dirname(__file__), 'matlab_codes')
    engine = matlab.engine.start_matlab()
    engine.addpath(codes_dir)
    return engine
def MiND_ML(X, dists, k):
    """Estimate the intrinsic dimension of ``X`` with the MATLAB MiND_ML estimator.

    A MATLAB engine is started for this call and always shut down afterwards.

    Args:
        X: data matrix; it is transposed before being sent to MATLAB (which
            expects column-wise vectors), so rows are presumably samples.
        dists: neighbour-distance matrix (e.g. from ``kNN``); only its first
            ``k + 2`` columns are forwarded, unnormalized.
        k: neighbourhood size, forwarded to MATLAB as ``'k'``.

    Returns:
        The estimated dimension as returned by MATLAB (``'optimize'`` enabled).
    """
    import matlab.engine
    eng = start_matlab_engine()
    try:
        X_mat = matlab.double(X.T.tolist())
        dists_mat = matlab.double(dists[:, :k+2].tolist())
        # Forward k explicitly, as MiND_KL and DANCo do. Without it, MiND_ML.m
        # presumably falls back to its own default k and ignores the caller's.
        dim = eng.MiND_ML(X_mat, 'k', matlab.double([k]), 'dists', dists_mat,
                          'normalized', False, 'optimize', True)
    finally:
        # Each call launches its own MATLAB process; do not leak it.
        eng.quit()
    return dim
def MiND_KL(X, dists, k, maxDim=30):
    """Estimate the intrinsic dimension of ``X`` with the MATLAB MiND_KL estimator.

    A MATLAB engine is started for this call and always shut down afterwards.

    Args:
        X: data matrix; transposed before being sent to MATLAB (column-wise
            vectors there), so rows are presumably samples.
        dists: neighbour-distance matrix; only its first ``k + 2`` columns are
            forwarded, unnormalized.
        k: neighbourhood size.
        maxDim: largest candidate dimension passed to MATLAB.

    Returns:
        The estimated dimension as returned by MATLAB.
    """
    import matlab.engine
    eng = start_matlab_engine()
    try:
        X_mat = matlab.double(X.T.tolist())
        dists_mat = matlab.double(dists[:, :k+2].tolist())
        dim = eng.MiND_KL(X_mat, 'k', matlab.double([k]), 'maxDim', matlab.double([maxDim]),
                          'dists', dists_mat, 'normalized', False, nargout=1)
    finally:
        # Each call launches its own MATLAB process; do not leak it.
        eng.quit()
    return dim
def DANCo(X, dists, inds, k, maxDim=30):
    """Estimate the intrinsic dimension of ``X`` with the MATLAB DANCo estimator.

    A MATLAB engine is started for this call and always shut down afterwards.

    Args:
        X: data matrix; transposed before being sent to MATLAB (column-wise
            vectors there), so rows are presumably samples.
        dists: neighbour-distance matrix; only its first ``k + 2`` columns are
            forwarded, unnormalized.
        inds: 0-based neighbour-index matrix matching ``dists``.
        k: neighbourhood size.
        maxDim: largest candidate dimension passed to MATLAB.

    Returns:
        The estimated dimension as returned by MATLAB (``'fractal'`` enabled).
    """
    import matlab.engine
    eng = start_matlab_engine()
    try:
        X_mat = matlab.double(X.T.tolist())
        dists_mat = matlab.double(dists[:, :k+2].tolist())
        # Shift 0-based indices to MATLAB's 1-based indexing.
        inds_mat = matlab.int32((inds[:, :k+2]+1).tolist())
        dim = eng.DANCo(X_mat, 'k', matlab.double([k]), 'maxDim', matlab.double([maxDim]), 'fractal', True,
                        'dists', dists_mat, 'inds', inds_mat, 'normalized', False, nargout=1)
    finally:
        # Each call launches its own MATLAB process; do not leak it.
        eng.quit()
    return dim
def Hein_CD(X):
    """Run the MATLAB ``GetDim`` routine on ``X`` and return its two outputs.

    A MATLAB engine is started for this call and always shut down afterwards.

    Args:
        X: data matrix; transposed before being sent to MATLAB (column-wise
            vectors there), so rows are presumably samples.

    Returns:
        The first two entries of the first row of ``GetDim``'s result.
        Presumably these are two dimension estimates; confirm which is which
        against GetDim's documentation.
    """
    import matlab.engine
    eng = start_matlab_engine()
    try:
        X_mat = matlab.double(X.T.tolist())
        dim = eng.GetDim(X_mat, nargout=1)
        # Convert to numpy while the engine is still alive.
        dim = np.array(dim)[0]
    finally:
        # Each call launches its own MATLAB process; do not leak it.
        eng.quit()
    return dim[0], dim[1]
================================================
FILE: analysis/latent_regression/__init__.py
================================================
from .regressors import *
================================================
FILE: analysis/latent_regression/regressors.py
================================================
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.neural_network import MLPRegressor
from sklearn.linear_model import LinearRegression
def mlp_regress(X_train, X_test, y_train, y_test, random_state=0):
    """Fit an MLP regressor on standardized features and report mean absolute errors.

    Features are standardized with a ``StandardScaler`` fitted on the training
    set. Targets are divided by ``y_norm = max(|y_train|)`` for training and
    rescaled back before the errors are computed.

    Args:
        X_train, X_test: training / test feature matrices.
        y_train, y_test: training / test targets.
        random_state: seed forwarded to ``MLPRegressor``.

    Returns:
        ``(train_error, test_error, info)``. The errors are mean absolute
        errors in the original target units. ``info`` holds the fitted
        ``'scaler'`` and ``'model'``, plus the target scale ``'y_norm'``:
        the model predicts ``y / y_norm``, so multiply its outputs by
        ``y_norm`` to recover targets.
    """
    scaler = StandardScaler()
    nn = MLPRegressor(hidden_layer_sizes=(64, 64, 64, 64, 64), activation='relu', solver='adam',
                      learning_rate='adaptive', alpha=0.01, random_state=random_state)
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    y_norm = np.max(np.abs(y_train))
    if y_norm == 0:
        # All-zero targets: dividing by 0 would turn them into NaNs, so keep
        # them unscaled instead.
        y_norm = 1.0
    nn.fit(X_train_scaled, y_train / y_norm)
    train_error = np.mean(np.abs(nn.predict(X_train_scaled) * y_norm - y_train))
    test_error = np.mean(np.abs(nn.predict(X_test_scaled) * y_norm - y_test))
    return train_error, test_error, {'scaler': scaler, 'model': nn, 'y_norm': y_norm}
def lin_regress(X_train, X_test, y_train, y_test):
    """Fit ordinary least squares on the training set and report mean absolute errors.

    Returns:
        ``(train_error, test_error, model)``, where the errors are the mean
        absolute prediction errors on the training and test sets and
        ``model`` is the fitted ``LinearRegression``.
    """
    model = LinearRegression()
    model.fit(X_train, y_train)

    def _mean_abs_error(features, targets):
        # Mean absolute deviation between predictions and targets.
        return np.mean(np.abs(model.predict(features) - targets))

    return _mean_abs_error(X_train, y_train), _mean_abs_error(X_test, y_test), model
def pca(X_train, X_test, num_components, random_state=0):
    """Fit a PCA on ``X_train`` and project both splits onto its components.

    Args:
        X_train: data the PCA is fitted on.
        X_test: data projected with the PCA fitted on ``X_train``.
        num_components: number of principal components kept.
        random_state: seed forwarded to ``PCA``.

    Returns:
        ``(X_train_projected, X_test_projected, fitted_pca)``.
    """
    projector = PCA(num_components, random_state=random_state)
    projector.fit(X_train)
    train_proj = projector.transform(X_train)
    test_proj = projector.transform(X_test)
    return train_proj, test_proj, projector
================================================
FILE: configs/air_dancer/latentpred/config1.yaml
================================================
lr: 0.01
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/latentpred/config2.yaml
================================================
lr: 0.01
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/latentpred/config3.yaml
================================================
lr: 0.02
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/refine64/config1.yaml
================================================
lr: 0.0006
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/refine64/config2.yaml
================================================
lr: 0.0003
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/air_dancer/refine64/config3.yaml
================================================
lr: 0.0004
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'air_dancer'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/refine64/config1.yaml
================================================
lr: 0.0003
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/refine64/config2.yaml
================================================
lr: 0.0003
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/circular_motion/refine64/config3.yaml
================================================
lr: 0.0003
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'circular_motion'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/latentpred/config1.yaml
================================================
lr: 0.03
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/latentpred/config2.yaml
================================================
lr: 0.03
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/latentpred/config3.yaml
================================================
lr: 0.03
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/refine64/config1.yaml
================================================
lr: 0.0003
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/refine64/config2.yaml
================================================
lr: 0.0005
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/double_pendulum/refine64/config3.yaml
================================================
lr: 0.0004
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 256
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'double_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/latentpred/config1.yaml
================================================
lr: 0.02
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/latentpred/config2.yaml
================================================
lr: 0.02
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/latentpred/config3.yaml
================================================
lr: 0.01
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/refine64/config1.yaml
================================================
lr: 0.0007
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/refine64/config2.yaml
================================================
lr: 0.0004
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/elastic_pendulum/refine64/config3.yaml
================================================
lr: 0.0002
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 256
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'elastic_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/latentpred/config1.yaml
================================================
lr: 0.008
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/latentpred/config2.yaml
================================================
lr: 0.02
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/latentpred/config3.yaml
================================================
lr: 0.009
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/refine64/config1.yaml
================================================
lr: 0.0007
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/refine64/config2.yaml
================================================
lr: 0.0006
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/fire/refine64/config3.yaml
================================================
lr: 0.0007
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'fire'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/latentpred/config1.yaml
================================================
lr: 0.008
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/latentpred/config2.yaml
================================================
lr: 0.009
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/latentpred/config3.yaml
================================================
lr: 0.01
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/refine64/config1.yaml
================================================
lr: 0.0006
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/refine64/config2.yaml
================================================
lr: 0.0004
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/lava_lamp/refine64/config3.yaml
================================================
lr: 0.0005
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'lava_lamp'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/latentpred/config1.yaml
================================================
lr: 0.01
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/latentpred/config2.yaml
================================================
lr: 0.03
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/latentpred/config3.yaml
================================================
lr: 0.01
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/refine64/config1.yaml
================================================
lr: 0.0004
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/refine64/config2.yaml
================================================
lr: 0.0003
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/reaction_diffusion/refine64/config3.yaml
================================================
lr: 0.0003
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 256
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'reaction_diffusion'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/latentpred/config1.yaml
================================================
lr: 0.005
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/latentpred/config2.yaml
================================================
lr: 0.01
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/latentpred/config3.yaml
================================================
lr: 0.02
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/refine64/config1.yaml
================================================
lr: 0.0003
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/refine64/config2.yaml
================================================
lr: 0.0003
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/single_pendulum/refine64/config3.yaml
================================================
lr: 0.0004
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'single_pendulum'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/refine64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/refine64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_magnetic/refine64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'swingstick_magnetic'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/latentpred/config1.yaml
================================================
lr: 0.02
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/latentpred/config2.yaml
================================================
lr: 0.03
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/latentpred/config3.yaml
================================================
lr: 0.01
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'latent-prediction'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/model/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/model/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/model/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/model64/config1.yaml
================================================
lr: 0.001
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/model64/config2.yaml
================================================
lr: 0.001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/model64/config3.yaml
================================================
lr: 0.001
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'encoder-decoder-64'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [20, 50, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/refine64/config1.yaml
================================================
lr: 0.0005
seed: 1
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/refine64/config2.yaml
================================================
lr: 0.0001
seed: 2
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: configs/swingstick_non_magnetic/refine64/config3.yaml
================================================
lr: 0.0004
seed: 3
if_cuda: True
gamma: 0.5
log_dir: 'logs'
train_batch: 512
val_batch: 256
test_batch: 256
num_workers: 8
model_name: 'refine-64'
data_filepath: './data/'
dataset: 'swingstick_non_magnetic'
lr_schedule: [15, 30, 100, 300]
num_gpus: 1
epochs: 1000
================================================
FILE: datainfo/README.md
================================================
## Data Collection
- single_pendulum, elastic_pendulum, reaction_diffusion, circular_motion: simulation code used to generate the simulated data.
- double_pendulum: post-processing scripts.
- fire, lava_lamp: scripts that split long recorded sequences into shorter ones.
- utils: shared utilities for splitting long sequences.
## Data Structure for Using Your Own Dataset
All of our datasets follow the structure below. To use the existing dataloader, organize any new dataset the same way: one numbered folder per trajectory, each containing its frames as numbered `.png` images.
If you want to use a different file organization, you will need to write your own dataloader.
```
double_pendulum/
├── 0/
│   ├── 0.png
│   ├── 1.png
│   └── ...
├── 1/
│   └── ...
├── 2/
│   └── ...
└── ...
```
================================================
FILE: datainfo/air_dancer/data_split_dict_1.json
================================================
{
"test": [
2,
18,
4
],
"val": [
15,
3,
8
],
"train": [
23,
24,
11,
10,
22,
1,
5,
25,
19,
13,
9,
17,
16,
0,
7,
21,
6,
12,
20,
14
]
}
================================================
FILE: datainfo/air_dancer/data_split_dict_2.json
================================================
{
"test": [
24,
2,
1
],
"val": [
9,
5,
11
],
"train": [
3,
20,
17,
15,
0,
23,
4,
7,
14,
16,
19,
22,
12,
18,
10,
13,
21,
25,
6,
8
]
}
================================================
FILE: datainfo/air_dancer/data_split_dict_3.json
================================================
{
"test": [
17,
18,
7
],
"val": [
19,
11,
4
],
"train": [
6,
16,
23,
14,
1,
5,
21,
10,
9,
13,
25,
12,
3,
8,
22,
20,
0,
2,
24,
15
]
}
================================================
FILE: datainfo/circular_motion/data_split_dict_1.json
================================================
{
"test": [
524,
1032,
960,
907,
977,
840,
877,
802,
392,
5,
746,
980,
623,
1068,
675,
1098,
931,
470,
361,
591,
867,
975,
352,
526,
414,
237,
561,
880,
942,
552,
1059,
789,
12,
1005,
464,
345,
348,
806,
631,
89,
961,
60,
1002,
758,
1046,
335,
221,
898,
177,
767,
751,
354,
848,
827,
497,
983,
70,
805,
1034,
1022,
581,
621,
388,
1039,
1072,
1025,
681,
247,
607,
380,
204,
852,
44,
593,
941,
448,
472,
707,
477,
1015,
896,
454,
59,
864,
443,
780,
18,
52,
45,
62,
650,
209,
468,
545,
912,
4,
886,
798,
58,
999,
192,
429,
777,
967,
920,
1014,
241,
522,
129,
275
],
"val": [
456,
164,
736,
36,
149,
406,
1075,
230,
21,
1017,
442,
620,
214,
1096,
747,
259,
111,
264,
1089,
815,
431,
351,
395,
319,
24,
116,
485,
508,
329,
719,
465,
301,
728,
663,
279,
672,
172,
540,
1063,
163,
171,
71,
297,
1011,
189,
639,
944,
112,
1004,
255,
287,
773,
772,
14,
463,
17,
888,
85,
72,
689,
861,
33,
261,
836,
871,
816,
564,
817,
93,
881,
185,
598,
563,
181,
1079,
235,
823,
28,
614,
469,
339,
627,
1033,
638,
553,
551,
1,
1040,
424,
365,
832,
496,
423,
516,
963,
1019,
567,
583,
373,
890,
492,
57,
972,
436,
210,
574,
796,
531,
132,
828
],
"train": [
391,
479,
51,
1077,
882,
625,
268,
197,
433,
635,
246,
685,
640,
803,
131,
398,
1070,
580,
924,
863,
962,
238,
39,
763,
653,
455,
608,
869,
680,
768,
634,
107,
514,
509,
884,
417,
219,
146,
311,
630,
996,
290,
80,
911,
366,
353,
130,
814,
1001,
765,
100,
565,
487,
498,
158,
147,
862,
887,
13,
529,
295,
444,
916,
795,
722,
260,
1016,
950,
609,
284,
895,
542,
535,
381,
450,
294,
750,
502,
494,
326,
556,
67,
1057,
384,
933,
143,
644,
362,
1084,
120,
16,
1094,
706,
788,
233,
905,
1036,
596,
733,
853,
418,
254,
119,
26,
559,
575,
549,
1088,
834,
966,
202,
434,
1095,
95,
308,
507,
590,
154,
988,
357,
1008,
425,
360,
234,
778,
819,
883,
90,
793,
611,
1028,
901,
812,
740,
432,
923,
68,
289,
167,
679,
784,
1093,
186,
671,
891,
849,
316,
729,
669,
213,
375,
801,
304,
991,
569,
9,
633,
603,
428,
190,
940,
642,
368,
771,
857,
282,
263,
419,
330,
1031,
420,
655,
876,
520,
1041,
544,
194,
791,
81,
122,
859,
3,
656,
1042,
201,
956,
753,
266,
262,
782,
243,
334,
54,
99,
341,
140,
280,
1027,
421,
787,
43,
401,
96,
115,
989,
1071,
973,
56,
734,
1061,
617,
850,
30,
928,
534,
573,
1067,
968,
949,
499,
343,
868,
990,
837,
242,
1092,
482,
1083,
992,
568,
350,
723,
969,
612,
467,
344,
810,
503,
1087,
693,
807,
619,
174,
717,
1076,
594,
29,
624,
586,
718,
336,
809,
926,
616,
441,
409,
1045,
1053,
954,
134,
451,
501,
537,
731,
533,
1010,
1006,
491,
708,
206,
688,
125,
927,
113,
770,
459,
709,
453,
390,
200,
713,
402,
906,
152,
538,
25,
379,
654,
766,
188,
267,
493,
661,
137,
776,
121,
764,
687,
430,
978,
570,
800,
47,
82,
701,
489,
959,
396,
762,
1020,
35,
216,
939,
50,
1007,
1026,
23,
948,
702,
438,
248,
987,
878,
77,
155,
278,
745,
135,
178,
998,
309,
37,
676,
759,
775,
870,
79,
1048,
539,
712,
699,
148,
412,
613,
986,
385,
34,
446,
484,
915,
207,
774,
979,
199,
657,
378,
269,
184,
371,
271,
856,
845,
394,
532,
359,
997,
673,
476,
739,
478,
236,
821,
515,
372,
665,
865,
695,
831,
965,
285,
879,
892,
958,
331,
1000,
666,
1090,
889,
374,
1091,
974,
276,
156,
179,
585,
643,
226,
935,
0,
714,
781,
790,
510,
1049,
257,
223,
447,
293,
651,
530,
825,
69,
332,
370,
1052,
548,
1065,
525,
6,
382,
291,
474,
726,
413,
55,
415,
452,
104,
196,
292,
756,
519,
1078,
215,
647,
716,
913,
779,
437,
1013,
582,
866,
748,
829,
571,
705,
299,
527,
139,
698,
277,
75,
606,
88,
1066,
838,
427,
124,
648,
321,
15,
725,
296,
486,
270,
97,
2,
286,
383,
193,
127,
808,
151,
338,
854,
483,
1080,
1050,
307,
677,
61,
108,
636,
700,
318,
684,
546,
166,
909,
84,
696,
1035,
462,
602,
481,
195,
760,
860,
123,
473,
435,
227,
517,
358,
900,
337,
495,
715,
315,
397,
785,
142,
936,
737,
622,
769,
660,
1074,
168,
393,
356,
403,
1018,
53,
833,
855,
659,
159,
180,
386,
641,
126,
87,
475,
169,
1058,
599,
844,
902,
208,
543,
918,
377,
523,
73,
595,
224,
281,
1024,
422,
1085,
490,
506,
724,
244,
231,
1064,
512,
953,
541,
952,
7,
994,
239,
662,
176,
410,
407,
730,
914,
572,
749,
141,
363,
63,
19,
841,
903,
1043,
1069,
830,
1038,
405,
11,
1023,
597,
449,
670,
678,
562,
440,
457,
683,
839,
253,
49,
161,
550,
946,
306,
182,
820,
566,
323,
32,
558,
145,
211,
1097,
300,
1012,
109,
312,
938,
144,
153,
183,
1037,
826,
743,
652,
513,
943,
157,
674,
505,
367,
921,
692,
22,
76,
847,
757,
930,
752,
804,
249,
20,
872,
250,
94,
792,
649,
1029,
874,
103,
342,
251,
310,
592,
682,
191,
799,
42,
668,
811,
232,
985,
658,
588,
92,
458,
91,
929,
738,
369,
252,
203,
314,
710,
554,
187,
265,
364,
480,
704,
632,
220,
256,
114,
466,
615,
324,
65,
64,
408,
320,
400,
460,
327,
610,
727,
10,
27,
908,
325,
667,
212,
102,
322,
488,
894,
741,
910,
555,
744,
445,
105,
842,
1056,
697,
919,
1009,
118,
165,
899,
600,
245,
897,
754,
843,
971,
947,
932,
686,
628,
1021,
735,
46,
110,
283,
1081,
893,
786,
577,
302,
993,
273,
83,
579,
229,
951,
584,
846,
824,
601,
629,
117,
349,
851,
150,
389,
74,
416,
40,
858,
813,
945,
1062,
797,
500,
732,
618,
783,
298,
904,
1054,
346,
376,
917,
995,
957,
340,
274,
794,
964,
170,
173,
136,
86,
41,
742,
66,
240,
1082,
970,
547,
703,
922,
560,
1051,
98,
818,
272,
218,
439,
347,
138,
576,
1047,
955,
160,
885,
288,
411,
626,
333,
934,
511,
981,
303,
399,
1055,
106,
504,
198,
605,
1073,
690,
587,
1044,
101,
355,
205,
387,
835,
521,
637,
720,
1086,
175,
471,
976,
222,
604,
38,
984,
8,
133,
258,
578,
426,
162,
761,
873,
317,
78,
937,
313,
48,
217,
128,
305,
755,
1030,
875,
646,
1003,
328,
822,
589,
691,
404,
31,
664,
536,
228,
461,
528,
711,
925,
645,
225,
1060,
557,
982,
694,
518,
721
]
}
================================================
FILE: datainfo/circular_motion/data_split_dict_2.json
================================================
{
"test": [
24,
688,
255,
176,
368,
371,
58,
32,
777,
734,
919,
433,
61,
901,
966,
215,
844,
250,
272,
874,
139,
534,
772,
108,
937,
896,
698,
232,
605,
1066,
50,
668,
588,
60,
217,
391,
17,
699,
154,
750,
1001,
425,
638,
832,
1032,
621,
633,
982,
1082,
340,
664,
454,
996,
935,
718,
1062,
931,
1057,
1025,
1020,
571,
1003,
511,
944,
818,
330,
1060,
741,
724,
1079,
849,
912,
372,
1052,
736,
1065,
1044,
279,
355,
665,
361,
48,
472,
483,
363,
336,
867,
778,
652,
952,
745,
56,
1090,
549,
1028,
911,
761,
1042,
805,
882,
324,
73,
434,
515,
631,
346,
739,
173,
187,
115
],
"val": [
810,
639,
233,
1017,
82,
256,
75,
455,
731,
238,
252,
1041,
725,
520,
237,
650,
464,
98,
174,
165,
134,
34,
259,
995,
686,
143,
1049,
716,
18,
1048,
429,
620,
268,
265,
349,
924,
147,
335,
527,
498,
402,
799,
598,
530,
986,
895,
807,
458,
989,
104,
858,
323,
703,
676,
95,
230,
484,
157,
722,
636,
883,
411,
773,
270,
923,
46,
757,
619,
784,
564,
459,
315,
31,
500,
345,
292,
1098,
765,
760,
642,
630,
352,
4,
37,
155,
253,
813,
44,
603,
394,
1,
708,
535,
188,
752,
160,
958,
1033,
130,
261,
382,
21,
940,
746,
41,
25,
69,
977,
117,
84
],
"train": [
521,
97,
212,
436,
465,
835,
410,
195,
591,
288,
988,
828,
450,
127,
42,
93,
92,
970,
873,
239,
424,
1086,
956,
310,
266,
824,
142,
717,
119,
727,
1008,
357,
871,
584,
64,
396,
817,
406,
57,
2,
681,
9,
249,
864,
512,
198,
245,
40,
473,
1089,
299,
182,
601,
379,
321,
955,
763,
486,
766,
22,
906,
744,
171,
316,
702,
467,
370,
898,
274,
820,
1015,
644,
443,
707,
13,
576,
764,
109,
838,
559,
158,
26,
670,
721,
329,
789,
798,
802,
767,
133,
431,
539,
325,
672,
485,
618,
140,
135,
438,
692,
408,
608,
997,
640,
29,
344,
572,
421,
748,
6,
327,
1039,
235,
308,
796,
635,
290,
581,
713,
178,
612,
943,
568,
866,
207,
546,
519,
917,
889,
214,
339,
111,
448,
35,
63,
248,
282,
405,
163,
629,
782,
226,
659,
850,
902,
487,
1058,
592,
540,
596,
505,
831,
801,
575,
1046,
800,
833,
963,
556,
758,
504,
347,
236,
206,
441,
682,
1085,
547,
1040,
507,
87,
1094,
580,
975,
604,
648,
569,
294,
7,
916,
607,
788,
228,
281,
573,
376,
219,
915,
241,
167,
208,
1080,
216,
529,
1075,
94,
690,
907,
542,
578,
146,
120,
968,
49,
862,
211,
680,
447,
815,
586,
1068,
488,
693,
306,
271,
269,
314,
43,
792,
797,
976,
417,
954,
191,
499,
151,
729,
854,
848,
552,
781,
307,
407,
826,
881,
985,
1097,
1069,
397,
423,
508,
168,
76,
998,
627,
439,
860,
258,
229,
105,
645,
590,
751,
377,
759,
791,
597,
80,
106,
887,
891,
386,
753,
1054,
322,
100,
809,
965,
300,
843,
661,
913,
899,
617,
283,
855,
567,
99,
842,
926,
737,
574,
634,
967,
440,
987,
948,
637,
960,
886,
416,
562,
51,
513,
1019,
662,
732,
803,
66,
356,
460,
419,
897,
152,
932,
193,
359,
945,
606,
476,
583,
660,
624,
786,
369,
466,
506,
110,
557,
403,
1012,
351,
378,
1091,
981,
276,
415,
0,
251,
427,
953,
432,
942,
770,
683,
616,
566,
877,
787,
677,
456,
257,
70,
15,
286,
71,
929,
33,
123,
689,
756,
518,
332,
827,
47,
780,
614,
790,
12,
876,
112,
632,
541,
671,
129,
845,
694,
234,
156,
201,
482,
490,
884,
348,
1022,
1021,
302,
380,
921,
738,
231,
1084,
197,
210,
723,
1016,
354,
205,
964,
1038,
517,
755,
593,
1064,
305,
918,
278,
691,
1005,
1006,
242,
675,
200,
27,
116,
492,
430,
1073,
1030,
949,
687,
59,
295,
267,
920,
613,
1009,
398,
227,
880,
890,
90,
1081,
951,
785,
865,
331,
190,
172,
23,
312,
1061,
122,
362,
126,
646,
175,
45,
442,
678,
563,
446,
128,
169,
857,
775,
719,
1070,
991,
806,
623,
754,
77,
615,
684,
145,
463,
199,
304,
714,
666,
816,
863,
277,
868,
223,
333,
841,
878,
532,
808,
469,
39,
183,
85,
179,
53,
14,
65,
861,
704,
565,
1010,
457,
973,
244,
225,
647,
461,
132,
859,
247,
202,
936,
972,
962,
343,
544,
318,
1007,
149,
747,
959,
611,
385,
930,
479,
360,
445,
365,
384,
1092,
328,
287,
350,
409,
387,
743,
555,
1083,
1011,
728,
189,
776,
194,
480,
243,
101,
162,
334,
1018,
643,
649,
137,
847,
298,
452,
749,
170,
994,
554,
730,
79,
1043,
326,
218,
373,
388,
1037,
1045,
62,
284,
971,
705,
836,
888,
983,
622,
91,
516,
153,
543,
957,
837,
892,
933,
922,
673,
319,
341,
311,
651,
264,
783,
309,
551,
742,
136,
510,
879,
946,
453,
577,
654,
550,
900,
814,
969,
641,
570,
289,
10,
1095,
301,
653,
390,
1051,
166,
55,
138,
978,
908,
1034,
1067,
81,
28,
471,
514,
273,
470,
762,
496,
829,
67,
658,
795,
3,
358,
74,
209,
478,
177,
579,
525,
667,
493,
834,
8,
141,
626,
501,
779,
594,
600,
669,
204,
905,
812,
856,
220,
395,
910,
582,
221,
655,
1047,
823,
181,
822,
1026,
846,
320,
11,
894,
1024,
412,
192,
938,
545,
934,
83,
338,
560,
875,
1072,
710,
925,
1002,
414,
280,
494,
54,
522,
999,
184,
392,
609,
30,
451,
477,
148,
587,
367,
413,
240,
1071,
161,
1096,
263,
526,
254,
528,
1074,
720,
872,
553,
159,
114,
293,
536,
992,
449,
337,
711,
903,
774,
399,
735,
131,
38,
663,
697,
88,
68,
426,
150,
558,
180,
364,
72,
1014,
589,
599,
86,
715,
695,
712,
437,
561,
1055,
610,
1078,
974,
852,
769,
303,
696,
503,
909,
1063,
489,
870,
495,
275,
383,
144,
468,
366,
297,
509,
1013,
947,
1050,
481,
674,
1035,
1059,
1087,
709,
840,
400,
656,
342,
1093,
939,
125,
853,
260,
124,
404,
353,
771,
904,
794,
733,
401,
979,
585,
118,
196,
502,
941,
224,
78,
740,
804,
811,
927,
291,
5,
885,
628,
602,
213,
474,
497,
1077,
420,
1029,
462,
16,
990,
793,
203,
313,
851,
103,
422,
839,
1053,
950,
381,
706,
296,
375,
625,
121,
531,
19,
374,
491,
821,
679,
96,
185,
830,
537,
428,
52,
1088,
595,
980,
523,
435,
444,
825,
1036,
1004,
1076,
701,
1023,
389,
657,
548,
317,
914,
475,
685,
533,
984,
222,
107,
893,
869,
186,
20,
102,
928,
246,
89,
1056,
524,
113,
164,
418,
393,
36,
1027,
961,
768,
538,
285,
1000,
700,
262,
993,
726,
1031,
819
]
}
================================================
FILE: datainfo/circular_motion/data_split_dict_3.json
================================================
{
"test": [
887,
659,
395,
532,
890,
471,
385,
386,
882,
918,
141,
981,
368,
321,
347,
888,
1003,
43,
706,
159,
269,
625,
298,
417,
994,
202,
971,
32,
548,
614,
110,
78,
7,
317,
1065,
483,
571,
677,
773,
92,
90,
243,
850,
1082,
601,
41,
1085,
840,
136,
704,
181,
990,
987,
129,
254,
583,
546,
432,
213,
668,
334,
572,
58,
689,
475,
834,
718,
790,
1038,
862,
1076,
893,
528,
444,
1013,
278,
73,
199,
748,
274,
910,
808,
874,
793,
968,
551,
63,
616,
87,
326,
131,
31,
798,
1071,
310,
474,
308,
813,
975,
963,
392,
479,
531,
960,
26,
134,
970,
757,
267,
487
],
"val": [
223,
897,
759,
344,
81,
173,
876,
829,
448,
229,
853,
691,
341,
329,
930,
104,
99,
714,
665,
445,
694,
191,
335,
244,
275,
823,
669,
227,
512,
1022,
1094,
752,
700,
582,
27,
832,
1056,
107,
996,
806,
307,
270,
988,
864,
936,
776,
320,
189,
372,
1039,
327,
615,
606,
305,
467,
643,
257,
378,
21,
692,
62,
974,
22,
501,
755,
285,
723,
623,
950,
939,
695,
361,
477,
340,
642,
648,
61,
1037,
647,
603,
630,
995,
20,
322,
593,
425,
807,
11,
1005,
561,
1083,
533,
264,
447,
1035,
958,
1030,
732,
737,
649,
441,
277,
519,
830,
1088,
635,
105,
1050,
697,
609
],
"train": [
401,
148,
464,
711,
957,
1084,
420,
1007,
318,
1074,
579,
972,
434,
2,
770,
220,
1010,
540,
192,
740,
857,
788,
545,
408,
352,
163,
469,
118,
758,
534,
506,
587,
504,
396,
863,
760,
231,
794,
1059,
1049,
768,
655,
456,
279,
313,
459,
769,
371,
80,
747,
66,
67,
525,
804,
521,
144,
1081,
894,
1060,
1066,
682,
439,
476,
142,
576,
620,
38,
672,
909,
345,
328,
1019,
193,
1087,
713,
179,
413,
156,
607,
640,
898,
1046,
511,
1027,
253,
558,
135,
781,
1098,
599,
945,
1073,
157,
905,
221,
197,
177,
59,
219,
154,
178,
486,
421,
727,
64,
53,
402,
1097,
355,
751,
228,
646,
631,
852,
452,
627,
1052,
309,
292,
462,
325,
784,
555,
83,
225,
149,
350,
1025,
1031,
468,
427,
969,
49,
846,
836,
262,
473,
1012,
520,
12,
573,
60,
212,
1048,
973,
820,
671,
374,
460,
249,
859,
171,
172,
198,
218,
407,
746,
908,
502,
112,
1036,
450,
594,
1089,
999,
931,
809,
301,
137,
236,
819,
639,
753,
843,
1002,
1069,
75,
207,
100,
214,
300,
346,
720,
8,
578,
89,
188,
272,
1070,
418,
84,
725,
167,
238,
875,
800,
470,
390,
1055,
835,
211,
312,
217,
656,
0,
405,
151,
941,
1026,
174,
375,
586,
304,
814,
315,
265,
855,
1072,
929,
143,
565,
778,
485,
779,
337,
201,
562,
399,
186,
709,
500,
608,
1080,
362,
29,
743,
621,
42,
48,
121,
949,
1040,
696,
266,
557,
685,
6,
690,
624,
10,
1044,
955,
952,
719,
424,
16,
680,
162,
415,
287,
575,
821,
411,
363,
780,
503,
754,
71,
14,
86,
935,
702,
394,
299,
717,
721,
921,
589,
977,
693,
658,
889,
916,
496,
114,
296,
185,
938,
842,
712,
792,
772,
19,
233,
454,
224,
1018,
896,
789,
966,
155,
698,
106,
216,
687,
872,
242,
449,
404,
82,
535,
480,
209,
998,
482,
629,
248,
113,
1041,
705,
1016,
673,
775,
40,
17,
370,
1034,
68,
841,
168,
165,
507,
9,
801,
762,
39,
259,
364,
316,
633,
803,
25,
319,
666,
239,
708,
914,
728,
145,
1057,
412,
815,
472,
518,
699,
108,
619,
397,
735,
564,
652,
684,
745,
251,
119,
338,
951,
943,
379,
97,
797,
387,
559,
1061,
466,
458,
351,
844,
1032,
324,
158,
679,
282,
161,
443,
208,
828,
15,
899,
605,
526,
103,
670,
653,
1045,
543,
451,
517,
597,
724,
976,
303,
65,
516,
1000,
574,
343,
1047,
290,
96,
357,
674,
416,
654,
667,
288,
617,
1058,
1095,
513,
437,
77,
57,
764,
498,
457,
736,
414,
822,
993,
206,
170,
928,
263,
393,
356,
681,
831,
628,
774,
937,
816,
913,
196,
138,
956,
750,
660,
200,
923,
294,
1064,
283,
510,
1028,
366,
707,
810,
947,
650,
595,
610,
246,
234,
95,
331,
153,
284,
854,
932,
638,
1001,
783,
169,
481,
72,
286,
55,
1075,
817,
901,
911,
455,
926,
91,
76,
741,
373,
489,
756,
777,
380,
505,
927,
967,
596,
367,
536,
180,
560,
252,
891,
1051,
1009,
585,
289,
538,
552,
946,
1024,
917,
644,
730,
591,
877,
377,
686,
895,
256,
989,
115,
442,
1020,
281,
924,
676,
786,
400,
742,
215,
984,
497,
388,
446,
339,
342,
30,
824,
997,
463,
365,
839,
851,
920,
664,
120,
866,
93,
881,
922,
925,
641,
884,
261,
845,
900,
140,
1015,
194,
867,
260,
865,
749,
715,
235,
478,
811,
767,
74,
802,
688,
570,
1090,
210,
491,
622,
147,
1067,
892,
645,
731,
176,
205,
391,
490,
1068,
23,
314,
222,
598,
791,
302,
550,
237,
703,
744,
150,
965,
273,
139,
190,
733,
453,
612,
611,
509,
98,
166,
152,
410,
837,
495,
529,
983,
878,
964,
164,
422,
240,
907,
661,
85,
701,
569,
358,
406,
613,
376,
785,
403,
959,
465,
255,
636,
247,
79,
184,
954,
332,
563,
782,
566,
102,
567,
109,
796,
541,
523,
663,
37,
1029,
306,
592,
291,
493,
117,
1093,
683,
787,
24,
183,
384,
101,
675,
94,
761,
45,
886,
241,
146,
1096,
5,
1033,
985,
544,
577,
1078,
111,
1092,
871,
795,
726,
860,
124,
268,
982,
383,
333,
28,
133,
436,
1004,
435,
738,
604,
276,
1017,
879,
849,
18,
716,
428,
430,
739,
508,
953,
381,
915,
979,
833,
44,
651,
722,
554,
880,
618,
1063,
382,
1006,
182,
580,
883,
389,
3,
568,
130,
126,
438,
336,
870,
271,
369,
311,
600,
398,
662,
991,
359,
1023,
869,
160,
323,
942,
514,
527,
88,
729,
827,
494,
70,
766,
127,
47,
1042,
56,
1,
330,
484,
52,
433,
992,
539,
632,
903,
1043,
848,
51,
409,
584,
934,
499,
1021,
1079,
280,
245,
799,
1086,
547,
125,
1014,
226,
360,
948,
1054,
553,
258,
128,
818,
116,
626,
54,
838,
549,
537,
961,
904,
902,
919,
515,
175,
873,
861,
492,
13,
50,
590,
440,
203,
524,
710,
35,
46,
250,
122,
293,
602,
69,
232,
349,
556,
295,
1091,
765,
678,
771,
933,
763,
33,
906,
734,
986,
1062,
522,
637,
488,
4,
204,
1008,
423,
36,
419,
581,
429,
297,
426,
978,
354,
1053,
885,
812,
530,
1011,
431,
132,
940,
353,
634,
825,
1077,
657,
847,
868,
348,
944,
187,
588,
858,
856,
826,
962,
195,
542,
34,
123,
805,
230,
980,
461,
912
]
}
================================================
FILE: datainfo/collect/circular_motion/make_data.py
================================================
import numpy as np
import os
from tqdm import tqdm
import shutil
from PIL import Image, ImageDraw
def mkdir(folder):
    """Create ``folder`` from scratch, wiping it first if it already exists."""
    already_there = os.path.exists(folder)
    if already_there:
        shutil.rmtree(folder)  # discard stale contents from a previous run
    os.makedirs(folder)
# background and object properties
img_size = (128, 128)  # (width, height) of the rendered image in pixels
bg_color = 'gray'
obj_size = 8  # half-width of the object's bounding box, in pixels
obj_shape = 'ball'  # 'ball' or 'square'
obj_color = 'blue'


def coord2img(x, y):
    """Render one object at normalized coordinates (x, y) on a blank canvas.

    Args:
        x: horizontal position; 0 maps to the left edge, 1 to the right edge.
        y: vertical position; 0 maps to the bottom edge, 1 to the top edge
            (image rows grow downwards, hence the ``1 - y`` flip below).

    Returns:
        A PIL RGB image of size ``img_size`` filled with ``bg_color`` and one
        ``obj_color`` shape whose bounding box extends ``obj_size`` pixels
        around (x, y) in each direction.

    Raises:
        ValueError: if ``obj_shape`` is neither 'ball' nor 'square'.
    """
    img = Image.new('RGB', img_size, bg_color)
    draw = ImageDraw.Draw(img)
    pos_x = int(x * img_size[0])
    pos_y = int((1 - y) * img_size[1])
    bbox = (pos_x - obj_size, pos_y - obj_size, pos_x + obj_size, pos_y + obj_size)
    if obj_shape == 'square':
        draw.rectangle(bbox, fill=obj_color)
    elif obj_shape == 'ball':
        draw.ellipse(bbox, fill=obj_color)
    else:
        # An explicit exception is not stripped by `python -O`, unlike `assert False`.
        raise ValueError(f"unsupported obj_shape: {obj_shape!r}")
    return img
data_filepath = './circular_motion'
num_exp = 1100
num_frames = 60
# Settings for a single long sequence instead:
# data_filepath = './circular_motion_long'
# num_exp = 1
# num_frames = 60*20
center = (0.5, 0.5)
radius = 0.3

np.random.seed(0)  # fixed seed -> reproducible dataset
for p_exp in tqdm(range(num_exp)):
    # Random start angle plus a per-frame angular step whose magnitude lies in
    # [pi/30, pi/10) and whose sign (direction of rotation) is random.
    theta = np.random.uniform(0, 2*np.pi)
    direction = np.random.choice((-1, 1))
    theta_d = direction * np.random.uniform(np.pi/30, np.pi/10)
    seq_filepath = os.path.join(data_filepath, str(p_exp))
    mkdir(seq_filepath)
    for p_frame in range(num_frames):
        frame_img = coord2img(center[0] + radius * np.cos(theta),
                              center[1] + radius * np.sin(theta))
        theta += theta_d
        frame_img.save(os.path.join(seq_filepath, f'{p_frame}.png'))
================================================
FILE: datainfo/collect/double_pendulum/README.md
================================================
## Double Pendulum Data Collection
Place the raw videos of the double pendulum motion in `./visphysics_data/double_pendulum_raw_videos`, relative to this folder.
Then run the following commands from this folder.
The processed data are saved in `./visphysics_data/double_pendulum`.
1. `python convert_video.py -f 60`: crop each video, resize it to 128x128 and convert it into a sequence of images sampled at 60 fps.
2. `python equalize_background.py -r 215 -g 205 -b 192`: replace the background pixels of every frame with the uniform color (R, G, B) = (215, 205, 192).
3. `python split_data.py`: split each sequence into subsequences of fixed length 60.
================================================
FILE: datainfo/collect/double_pendulum/convert_video.py
================================================
import cv2
import os
import shutil
from tqdm import tqdm
import argparse
import random
import time
## This script is for the purposes of extracting frames from given set of input videos
## at desired fps.
def mkdir(folder):
    """Make an empty directory at ``folder``.

    Any existing directory tree at that path is deleted first, so the caller
    always starts from a clean folder.
    """
    if not os.path.exists(folder):
        os.makedirs(folder)
        return
    shutil.rmtree(folder)
    os.makedirs(folder)
def get_arguments():
    """Parse the command-line options of the frame extractor.

    Returns:
        argparse.Namespace with ``fps`` (required target sampling rate) and
        ``seed`` (seed for shuffling the videos, default 0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--fps', required=True, type=int,
                        help='frames per second at which it is desirable to get output data')
    parser.add_argument('-s', '--seed', default=0, type=int,
                        help='random seed for shuffling')
    return parser.parse_args()
def write_frames(video_path, frame_path, fps, seed):
    """Extract frames from every video in ``video_path`` at about ``fps`` frames/s.

    Videos are shuffled with ``seed`` and numbered 0..N-1 in shuffled order.
    The kept frames of video ``idx`` are cropped to rows 0:620 and columns
    371:991, resized to 128x128 and written to ``frame_path/<idx>/<k>.png``,
    with ``k`` counting the kept frames from 0.

    NOTE(review): ``os.listdir`` order is filesystem-dependent, so the seeded
    shuffle only reproduces the same numbering for the same listing order.

    Args:
        video_path: directory containing the raw videos.
        frame_path: output root; each ``frame_path/<idx>`` is wiped and recreated.
        fps: desired output sampling rate (frames per second).
        seed: seed passed to ``random.seed`` before shuffling the video list.
    """
    listing = os.listdir(video_path)
    # remove macOS .DS_Store metadata files
    if '.DS_Store' in listing:
        listing.remove('.DS_Store')
    random.seed(seed)
    random.shuffle(listing)
    num_videos = len(listing)
    for idx, filename in enumerate(listing):
        print(idx+1, 'out of ', num_videos)
        cap = cv2.VideoCapture(os.path.join(video_path, filename))
        try:
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            print('video fps: ', video_fps)
            data_path = os.path.join(frame_path, str(idx))
            mkdir(data_path)
            # Keep every `skip`-th frame. round() yields 0 when the video's
            # frame rate is below half the target rate, which used to make the
            # modulo below raise ZeroDivisionError; clamp to 1 (keep all frames).
            skip = max(1, round(video_fps / fps))
            p_frame = 0  # index of the frame just read from the video
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if p_frame % skip == 0:
                    # Crop/resize only the frames that are actually written.
                    frame = frame[0:620, 371:991]  # manual crop
                    frame = cv2.resize(frame, (128, 128), interpolation=cv2.INTER_AREA)
                    cv2.imwrite(os.path.join(data_path, str(p_frame // skip) + '.png'), frame)
                p_frame += 1
        finally:
            # Release the capture even if reading or writing fails.
            cap.release()
        cv2.destroyAllWindows()
        time.sleep(1.0)  # short delay for cleaning purposes
# Input videos and output frame folders, relative to the working directory.
video_path = "./visphysics_data/double_pendulum_raw_videos"
frame_path = "./visphysics_data/double_pendulum_raw_data"
args = get_arguments()
write_frames(video_path, frame_path, fps=args.fps, seed=args.seed)
================================================
FILE: datainfo/collect/double_pendulum/equalize_background.py
================================================
import cv2
import os
import shutil
from tqdm import tqdm
import argparse
import numpy as np
'''
This script equalizes backgrounds of all frames with a uniform background.
The background color is user-specified or computed by averaging
background pixels of sampled frames from the data.
'''
def mkdir(folder):
    """Ensure ``folder`` exists and is empty (an existing tree is removed first)."""
    exists = os.path.exists(folder)
    if exists:
        shutil.rmtree(folder)
    os.makedirs(folder)  # also creates any missing parent folders
def get_arguments():
    """Parse the optional user-specified background color.

    Returns:
        argparse.Namespace with ``channel0``/``channel1``/``channel2`` holding
        the B/G/R values as floats, or ``None`` for any channel not given.
        The values are converted at parse time, so argparse reports a
        malformed value instead of a later failure when it is stored in a
        NumPy array.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('-b', '--channel0', type=float, help='b channel')
    ap.add_argument('-g', '--channel1', type=float, help='g channel')
    ap.add_argument('-r', '--channel2', type=float, help='r channel')
    return ap.parse_args()
raw_filepath = "./visphysics_data/double_pendulum_raw_data"
data_filepath = "./visphysics_data/double_pendulum_background_corrected_data"
num_exps = len(os.listdir(raw_filepath))
bg = np.zeros(3)  # background color in OpenCV's BGR channel order
args = get_arguments()
if args.channel0 is not None and args.channel1 is not None and args.channel2 is not None:
    # user-specified background color
    bg[0] = args.channel0
    bg[1] = args.channel1
    bg[2] = args.channel2
else:
    # Estimate a uniform background by averaging the background pixels of the
    # first `percent` fraction of frames of every sequence.
    percent = .05  # proportion of sampled frames
    cnt = 0
    print('Computing uniform background...')
    for p_exp in tqdm(range(num_exps)):
        seq_filepath = os.path.join(raw_filepath, str(p_exp))
        num_frames = len(os.listdir(seq_filepath))
        num_samples = int(percent * num_frames)
        for p_frame in range(num_samples):
            frame_path = os.path.join(seq_filepath, str(p_frame)+'.png')
            frame = cv2.imread(frame_path)
            if frame is None:
                raise FileNotFoundError('could not read frame: ' + frame_path)
            # Manual threshold: pixels whose red channel (index 2 in BGR)
            # exceeds 130 are treated as background. Vectorized instead of a
            # per-pixel Python loop; the integer sums are exact.
            mask = frame[:, :, 2] > 130
            bg += frame[mask].sum(axis=0)
            cnt += int(mask.sum())
    if cnt == 0:
        # Dividing by zero would silently yield a NaN background color.
        raise RuntimeError('no background pixels found in the sampled frames; '
                           'pass the background color with -r/-g/-b instead')
    # average over all sampled pixels
    bg = bg / cnt
print('Uniform Background R: ', bg[2], ' G: ', bg[1], ' B: ', bg[0])
# Apply the background to all data: every pixel whose red channel (index 2 in
# BGR) exceeds 130 is overwritten with the uniform color `bg`.
print('Applying the background...')
for p_exp in tqdm(range(num_exps)):
    seq_filepath = os.path.join(raw_filepath, str(p_exp))
    data_seq_filepath = os.path.join(data_filepath, str(p_exp))
    mkdir(data_seq_filepath)
    num_frames = len(os.listdir(seq_filepath))
    for p_frame in range(num_frames):
        frame_path = os.path.join(seq_filepath, str(p_frame)+'.png')
        frame = cv2.imread(frame_path)
        if frame is None:
            raise FileNotFoundError('could not read frame: ' + frame_path)
        # Vectorized manual threshold. The mask is computed on the unmodified
        # frame, exactly as the per-pixel test was. The float color is cast to
        # the frame's uint8 dtype on assignment, as per-pixel assignment did.
        frame[frame[:, :, 2] > 130] = bg
        # write to data
        cv2.imwrite(os.path.join(data_seq_filepath, str(p_frame)+'.png'), frame)
================================================
FILE: datainfo/collect/double_pendulum/split_data.py
================================================
import os
import shutil
from tqdm import tqdm
import argparse
'''
This script splits each sequence, which corresponds to a single experiment,
into subsequences with fixed length.
'''
def mkdir(folder):
    """Recreate ``folder`` as an empty directory, deleting any previous contents."""
    if not os.path.exists(folder):
        os.makedirs(folder)
        return
    shutil.rmtree(folder)
    os.makedirs(folder)
def get_arguments():
    """Parse the command line; ``len`` is the subsequence length (default 60)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', '--len', type=int, default=60,
                        help='length of each subsequence')
    return parser.parse_args()
src_filepath = "./visphysics_data/double_pendulum_background_corrected_data"
dst_filepath = "./visphysics_data/double_pendulum"
args = get_arguments()
num_exps = len(os.listdir(src_filepath))
seq_len = args.len

# Frames are copied to dst/<subseq_cnt>/<frame_cnt>.png; a fresh folder is
# opened as soon as the current one holds `seq_len` frames.
subseq_cnt, frame_cnt = 0, 0
mkdir(os.path.join(dst_filepath, str(subseq_cnt)))
for p_exp in tqdm(range(num_exps)):
    src_seq_filepath = os.path.join(src_filepath, str(p_exp))
    total = len(os.listdir(src_seq_filepath))
    # Drop trailing frames that would not fill a complete subsequence.
    usable = int(total / seq_len) * seq_len
    for p_frame in range(usable):
        shutil.copyfile(
            os.path.join(src_seq_filepath, f'{p_frame}.png'),
            os.path.join(dst_filepath, str(subseq_cnt), f'{frame_cnt}.png'),
        )
        frame_cnt += 1
        if frame_cnt == seq_len:
            subseq_cnt, frame_cnt = subseq_cnt + 1, 0
            mkdir(os.path.join(dst_filepath, str(subseq_cnt)))
# Every sequence contributes whole subsequences, so the folder opened last is
# always empty; remove it.
shutil.rmtree(os.path.join(dst_filepath, str(subseq_cnt)))
================================================
FILE: datainfo/collect/elastic_pendulum/make_data.py
================================================
import os
import shutil
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageDraw
from scipy import linalg
from scipy.integrate import solve_ivp
def mkdir(folder):
    """Create ``folder`` (with parents), removing an existing one beforehand."""
    already_there = os.path.exists(folder)
    if already_there:
        shutil.rmtree(folder)  # start from an empty folder
    os.makedirs(folder)
def engine(rng, num_frm, fps=60):
    """Simulate one elastic-pendulum trajectory from a random initial state.

    The system, per the parameter comments below, is an elastic pendulum of
    unstretched length ``L_0`` with a rigid rod pendulum of length ``L``
    attached to its end. Initial states are drawn from ``rng`` and redrawn
    until the whole trajectory passes the bounds check at the end of the loop.

    Args:
        rng: NumPy random Generator used to sample initial states
            (``make_data`` passes ``np.random.default_rng(seed)``).
        num_frm: number of time samples to return.
        fps: sampling rate; consecutive samples are ``1 / fps`` seconds apart.

    Returns:
        Array whose row k is the state
        ``[theta_1, theta_2, z, vel_theta_1, vel_theta_2, vel_z]`` at time
        ``k / fps``. Its shape is ``(num_frm, 6)`` when ``solve_ivp``
        succeeds; ``sol.success`` is not checked here.
    """
    # physical parameters
    L_0 = 0.205 # elastic pendulum unstretched length (m)
    L = 0.179 # rigid pendulum rod length (m)
    w = 0.038 # rigid pendulum rod width (m)
    m = 0.110 # rigid pendulum mass (kg)
    g = 9.81 # gravitational acceleration (m/s^2)
    k = 40.0 # elastic constant (kg/s^2)
    dt = 1.0 / fps
    t_eval = np.arange(num_frm) * dt  # sample times 0, dt, ..., (num_frm-1)*dt

    # solve equations of motion
    # y = [theta_1, theta_2, z, vel_theta_1, vel_theta_2, vel_z]
    def f(t, y):
        # Right-hand side for solve_ivp: returns dy/dt. The three accelerations
        # come from solving the 3x3 linear system A @ acc = b. A is the
        # state-dependent mass matrix; b collects the velocity-dependent,
        # gravity and spring (-(k/m)*z) terms.
        La = L_0 + y[2]  # current (stretched) length of the elastic pendulum
        Lb = 0.5 * L  # presumably the distance from the rod's joint to its center of mass
        I = (1. / 12.) * m * (w ** 2 + L ** 2)  # rod moment of inertia about its center (w x L rectangle formula)
        Jc = L * np.cos(y[0] - y[1])
        Js = L * np.sin(y[0] - y[1])
        A = np.array([[La**2, 0.5*La*Jc, 0],
                      [0.5*La*Jc, Lb**2+I/m, 0.5*Js],
                      [0, 0.5*Js, 1]])
        b = np.zeros(3)
        b[0] = -0.5*La*Js*y[4]**2 - 2*La*y[3]*y[5] - g*La*np.sin(y[0])
        b[1] = 0.5*La*Js*y[3]**2 - Jc*y[3]*y[5] - g*Lb*np.sin(y[1])
        b[2] = La*y[3]**2 + 0.5*Jc*y[4]**2 + g*np.cos(y[0]) - (k/m)*y[2]
        sol = linalg.solve(A, b)  # [acc_theta_1, acc_theta_2, acc_z]
        return [y[3], y[4], y[5], sol[0], sol[1], sol[2]]

    # run until the drawn pendulums are inside the image
    rej = True
    while rej:
        # Random initial state:
        # - both angles uniform in [0, 2*pi)
        # - spring displacement in [-0.04, 0.04] m
        # - angular velocities in [-6, 6] and [-10, 10]
        # - zero spring velocity
        # The draw order matters for reproducibility with a seeded rng.
        initial_state = [rng.uniform(0, 2*np.pi), rng.uniform(0, 2*np.pi), rng.uniform(-0.04, 0.04), rng.uniform(-6, 6), rng.uniform(-10, 10), 0]
        sol = solve_ivp(f, [t_eval[0], t_eval[-1]], initial_state, t_eval=t_eval, rtol=1e-6)
        states = sol.y.T  # solve_ivp returns (n_states, n_times); transpose to time-major
        # How far |x| / |y| of the rod's far end exceeds L_0 + L
        # (negative while inside that distance).
        lim_x = np.abs((L_0+states[:, 2])*np.sin(states[:, 0]) + L*np.sin(states[:, 1])) - (L_0 + L)
        lim_y = np.abs((L_0+states[:, 2])*np.cos(states[:, 0]) + L*np.cos(states[:, 1])) - (L_0 + L)
        # Reject if the end goes more than 0.16 m beyond that in either
        # direction at any time, or the spring is ever compressed by more
        # than 0.08 m.
        rej = (np.max(lim_x) > 0.16) | (np.max(lim_y) > 0.16) | (np.min(states[:, 2]) < -0.08)
    return states
def draw_rect(im, col, top_x, top_y, w, h, theta):
    """Fill a ``w`` x ``h`` rectangle rotated by ``theta`` onto ``im`` in color ``col``.

    ``(top_x, top_y)`` is the midpoint of the rectangle's top edge, in image
    coordinates. The rectangle extends a length ``h`` from that edge along
    the direction ``(sin(theta), cos(theta))``.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Walk around the four corners, starting at one end of the top edge.
    ax = top_x - w * cos_t / 2
    ay = top_y + w * sin_t / 2
    bx, by = ax + w * cos_t, ay - w * sin_t
    cx, cy = bx + h * sin_t, by + h * cos_t
    dx, dy = cx - w * cos_t, cy + w * sin_t
    ImageDraw.Draw(im).polygon([(ax, ay), (bx, by), (cx, cy), (dx, dy)], fill=col)
def render(theta_1, theta_2, z):
    """Draw an elastic-pendulum state and return it as a 128x128 PIL RGB image.

    The scene is drawn on a 1600x1600 canvas with the pivot at its center and
    then downsized. The first bar's drawn length is scaled by ``1 + z / L_0``,
    i.e. in proportion to the spring displacement ``z``.
    """
    background = (215, 205, 192)
    upper_color = (63, 66, 85)
    lower_color = (17, 93, 234)
    canvas = Image.new('RGB', (1600, 1600), background)
    pivot_x, pivot_y = 800, 800
    w1, w2, h1, h2 = 90, 90, 300, 250
    L_0 = 0.205  # unstretched elastic pendulum length (m)
    h1 *= (1 + z / L_0)
    # The second bar hangs from a point 35 px short of the first bar's end.
    joint_x = pivot_x + (h1 - 35) * np.sin(theta_1)
    joint_y = pivot_y + (h1 - 35) * np.cos(theta_1)
    # Paint the second bar first so the first bar is drawn on top of it.
    draw_rect(canvas, lower_color, joint_x, joint_y, w2, h2, theta_2)
    draw_rect(canvas, upper_color, pivot_x, pivot_y, w1, h1, theta_1)
    return canvas.resize((128, 128))
def make_data(data_filepath, num_seq, num_frm, seed=0):
    """Simulate and render ``num_seq`` elastic-pendulum sequences.

    Writes ``data_filepath/<n>/<k>.png`` for sequence n and frame k, plus all
    simulated states as ``data_filepath/states.npy`` with shape
    ``(num_seq, num_frm, 6)``. ``data_filepath`` is wiped first.
    """
    mkdir(data_filepath)
    rng = np.random.default_rng(seed)
    states = np.zeros((num_seq, num_frm, 6))
    for n in tqdm(range(num_seq)):
        seq_dir = os.path.join(data_filepath, str(n))
        mkdir(seq_dir)
        states[n, :, :] = engine(rng, num_frm)
        # Only the first three state entries (theta_1, theta_2, z) are drawn.
        for k, (theta_1, theta_2, z) in enumerate(states[n, :, :3]):
            render(theta_1, theta_2, z).save(os.path.join(seq_dir, f'{k}.png'))
    np.save(os.path.join(data_filepath, 'states.npy'), states)
if __name__ == '__main__':
    import argparse

    # The defaults reproduce the original run. The default output path looks
    # machine-specific, so it can be overridden on the command line instead
    # of editing this file.
    parser = argparse.ArgumentParser(description='Generate the elastic pendulum dataset.')
    parser.add_argument('--data_filepath', default='/data/kuang/visphysics_data/elastic_pendulum',
                        help='output folder (wiped before writing)')
    parser.add_argument('--num_seq', type=int, default=1200, help='number of sequences')
    parser.add_argument('--num_frm', type=int, default=60, help='frames per sequence')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    args = parser.parse_args()
    make_data(args.data_filepath, num_seq=args.num_seq, num_frm=args.num_frm, seed=args.seed)
================================================
FILE: datainfo/collect/fire/split_data.py
================================================
import os
import shutil
from tqdm import tqdm
import argparse
'''
This script splits each sequence, which corresponds to a single experiment,
into subsequences with fixed length.
'''
def mkdir(folder):
    """Make ``folder`` an empty directory, removing whatever was there before."""
    if not os.path.exists(folder):
        os.makedirs(folder)
        return
    shutil.rmtree(folder)
    os.makedirs(folder)
def get_arguments():
    """Read ``-l/--len``, the number of frames per output subsequence (default 60)."""
    cli = argparse.ArgumentParser()
    cli.add_argument('-l', '--len', type=int, default=60, help='length of each subsequence')
    parsed = cli.parse_args()
    return parsed
src_train_filepath = "./visphysics_data/fire/train"
src_test_filepath = "./visphysics_data/fire/test"
dst_filepath = "./visphysics_data/fire_60"
args = get_arguments()
num_train_exps = len(os.listdir(src_train_filepath))
num_test_exps = len(os.listdir(src_test_filepath))


def _split_sequence(src_seq_filepath, ext, subseq_cnt, frame_cnt):
    """Copy one source sequence into fixed-length subsequence folders.

    Frames ``<src_seq_filepath>/<i><ext>`` are appended to
    ``<dst_filepath>/<subseq_cnt>/<frame_cnt><ext>``. A new folder is
    (re)created whenever the current one reaches ``args.len`` frames.
    Returns the updated ``(subseq_cnt, frame_cnt)``.
    """
    num_frames = len(os.listdir(src_seq_filepath))
    # ignore frames at the end of the sequence
    num_frames = int(num_frames/args.len) * args.len
    for p_frame in range(num_frames):
        src_path = os.path.join(src_seq_filepath, str(p_frame) + ext)
        dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt) + ext)
        shutil.copyfile(src_path, dst_path)
        frame_cnt += 1
        if frame_cnt == args.len:
            subseq_cnt += 1
            frame_cnt = 0
            mkdir(os.path.join(dst_filepath, str(subseq_cnt)))
    return subseq_cnt, frame_cnt


subseq_cnt = 0
frame_cnt = 0
mkdir(os.path.join(dst_filepath, str(subseq_cnt)))
# The train split is read as .jpg frames and the test split as .png frames.
# train
for p_exp in tqdm(range(num_train_exps)):
    subseq_cnt, frame_cnt = _split_sequence(
        os.path.join(src_train_filepath, str(p_exp)), '.jpg', subseq_cnt, frame_cnt)
# test
for p_exp in tqdm(range(num_test_exps)):
    subseq_cnt, frame_cnt = _split_sequence(
        os.path.join(src_test_filepath, str(p_exp)), '.png', subseq_cnt, frame_cnt)
# remove the last empty folder
shutil.rmtree(os.path.join(dst_filepath, str(subseq_cnt)))
================================================
FILE: datainfo/collect/lava_lamp/split_data.py
================================================
import os
import shutil
from tqdm import tqdm
import argparse
'''
This script splits each sequence, which corresponds to a single experiment,
into subsequences with fixed length.
'''
def mkdir(folder):
    """Delete ``folder`` if present, then create it again (parents included)."""
    exists = os.path.exists(folder)
    if exists:
        shutil.rmtree(folder)
    os.makedirs(folder)
def get_arguments():
    """Parse ``-l/--len`` (frames per subsequence, default 60) from the command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-l', '--len',
        type=int,
        default=60,
        help='length of each subsequence',
    )
    return parser.parse_args()
src_train_filepath = "./visphysics_data/lavalamp/raw_data_train"
src_test_filepath = "./visphysics_data/lavalamp/raw_data_test"
dst_filepath = "./visphysics_data/lavalamp_60"
args = get_arguments()


def _split_sequence(src_seq_filepath, ext, subseq_cnt, frame_cnt):
    """Copy one source sequence into fixed-length subsequence folders.

    Frames ``<src_seq_filepath>/<i><ext>`` are appended to
    ``<dst_filepath>/<subseq_cnt>/<frame_cnt><ext>``. A new folder is
    (re)created whenever the current one reaches ``args.len`` frames.
    Returns the updated ``(subseq_cnt, frame_cnt)``.
    """
    num_frames = len(os.listdir(src_seq_filepath))
    # ignore frames at the end of the sequence
    num_frames = int(num_frames/args.len) * args.len
    for p_frame in range(num_frames):
        src_path = os.path.join(src_seq_filepath, str(p_frame) + ext)
        dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt) + ext)
        shutil.copyfile(src_path, dst_path)
        frame_cnt += 1
        if frame_cnt == args.len:
            subseq_cnt += 1
            frame_cnt = 0
            mkdir(os.path.join(dst_filepath, str(subseq_cnt)))
    return subseq_cnt, frame_cnt


subseq_cnt = 0
frame_cnt = 0
mkdir(os.path.join(dst_filepath, str(subseq_cnt)))
# Each source folder holds one continuous sequence of frames. Train frames
# are read as .jpg and test frames as .png.
subseq_cnt, frame_cnt = _split_sequence(src_train_filepath, '.jpg', subseq_cnt, frame_cnt)
subseq_cnt, frame_cnt = _split_sequence(src_test_filepath, '.png', subseq_cnt, frame_cnt)
# remove the last empty folder
shutil.rmtree(os.path.join(dst_filepath, str(subseq_cnt)))
================================================
FILE: datainfo/collect/reaction_diffusion/make_data.py
================================================
import os
import shutil
from tqdm import tqdm
import numpy as np
import scipy.io as sio
from PIL import Image
from matplotlib import cm
def mkdir(folder):
    """Ensure ``folder`` exists as a brand-new, empty directory.

    Missing parent directories are created. An existing path is removed
    recursively first.
    """
    if not os.path.exists(folder):
        os.makedirs(folder)
        return
    shutil.rmtree(folder)
    os.makedirs(folder)
def make_data(data_filepath, num_frm, mat_filepath='reaction_diffusion.mat'):
    """Render the simulated reaction-diffusion field as colour-mapped PNG sequences.

    Loads ``uf`` (indexed as ``uf[:, :, frame]``) and ``t`` from
    ``mat_filepath``. The frames are split into consecutive sequences of
    ``num_frm`` frames; trailing frames that do not fill a whole sequence are
    dropped. Frame ``k`` of sequence ``n`` is written to
    ``<data_filepath>/<n>/<k>.png``.

    Args:
        data_filepath: output folder; deleted and recreated.
        num_frm: number of frames per sequence.
        mat_filepath: MATLAB file to read. Defaults to the previously
            hard-coded 'reaction_diffusion.mat' in the working directory.
    """
    data = sio.loadmat(mat_filepath)
    mkdir(data_filepath)
    field = data['uf']
    num_seq = data['t'].size // num_frm
    # Fixed normalization: [-1, 1] is mapped linearly onto [0, 1] before the
    # colormap. NOTE(review): assumes uf stays within [-1, 1]; anything outside
    # falls outside the colormap's [0, 1] input range.
    u_min, u_max = -1, 1
    for n in tqdm(range(num_seq)):
        seq_filepath = os.path.join(data_filepath, str(n))
        mkdir(seq_filepath)
        for k in range(num_frm):
            u = field[:, :, n * num_frm + k]
            u = (u - u_min) / (u_max - u_min)
            rgb = cm.viridis(u)[:, :, :3]  # drop the alpha channel
            im = Image.fromarray(np.uint8(rgb * 255))
            im.save(os.path.join(seq_filepath, str(k) + '.png'))
if __name__ == '__main__':
    # Rendered PNG sequences go to ./reaction_diffusion (relative to the CWD),
    # 100 frames per sequence.
    data_filepath = 'reaction_diffusion'
    make_data(data_filepath, num_frm=100)
================================================
FILE: datainfo/collect/reaction_diffusion/reaction_diffusion.m
================================================
clear all; close all; clc
% Simulate the lambda-omega reaction-diffusion system
%   u_t = lam(A) u - ome(A) v + d1*(u_xx + u_yy)
%   v_t = ome(A) u + lam(A) v + d2*(v_xx + v_yy)
% with
%   A^2 = u^2 + v^2,  lam(A) = 1 - A^2,  ome(A) = -beta*A^2
% on an n-by-n grid (fft2 in space, ode45 in time), and save the masked
% u field to reaction_diffusion.mat so that uf(:,:,k) is the field at t(k).
t=0:0.2:2000.2;
d1=0.1; d2=0.1; beta=1.0;
L=20; n=128; N=n*n;
x2=linspace(-L/2,L/2,n+1); x=x2(1:n); y=x;   % drop the repeated endpoint
kx=(2*pi/L)*[0:(n/2-1) -n/2:-1]; ky=kx;       % wavenumbers in fft ordering

% INITIAL CONDITIONS
[X,Y]=meshgrid(x,y);
[KX,KY]=meshgrid(kx,ky);
K2=KX.^2+KY.^2; K22=reshape(K2,N,1);
m=1; % number of spirals
f = exp(-.01*(X.^2+Y.^2)); % circular mask to get rid of boundaries
u_ini=tanh(sqrt(X.^2+Y.^2)).*cos(m*angle(X+1i*Y)-(sqrt(X.^2+Y.^2)));
v_ini=tanh(sqrt(X.^2+Y.^2)).*sin(m*angle(X+1i*Y)-(sqrt(X.^2+Y.^2)));

% REACTION-DIFFUSION
% State vector: Fourier coefficients of u followed by those of v (length 2*N).
uvt=[reshape(fft2(u_ini),1,N) reshape(fft2(v_ini),1,N)].';
% Pass the model parameters through a function handle instead of the legacy
% ode45('name',tspan,y0,options,p1,p2,...) form. The rhs does not use its
% third (dummy) argument.
rhs_fun = @(tt,uv) reaction_diffusion_rhs(tt,uv,[],K22,d1,d2,beta,n,N);
[t,uvsol]=ode45(rhs_fun,t,uvt);

% Only the masked u field is saved. v and the time derivatives are therefore
% not reconstructed; each would be another n*n*length(t) array of doubles.
% Row j of uvsol (the state at t(j)) is stored in slice j+1; slice 1 stays 0.
uf = zeros(length(x),length(y),length(t));
for j=1:length(t)-1
    ut=reshape((uvsol(j,1:N).'),n,n);
    uf(:,:,j+1)=f.*real(ifft2(ut));
end

% Drop slice 1 (zeros) and slice 2 (the initial condition). Slice k holds the
% state at t(k-1), so the kept slices 3:end correspond to t(2:end-1). Trim the
% time vector to match; the previous t(3:end) was one step ahead of uf.
t = t(2:end-1);
uf = uf(:,:,3:end);
save('reaction_diffusion.mat','t','x','y','uf')
================================================
FILE: datainfo/collect/reaction_diffusion/reaction_diffusion_rhs.m
================================================
function rhs=reaction_diffusion_rhs(t,uvt,dummy,K22,d1,d2,beta,n,N)
% REACTION_DIFFUSION_RHS  Time derivative of the lambda-omega system in Fourier space.
%   rhs = reaction_diffusion_rhs(t,uvt,dummy,K22,d1,d2,beta,n,N)
%   uvt    - Fourier coefficients of u (entries 1..N) followed by those of v
%            (entries N+1..2N), each an n-by-n field flattened with reshape.
%   K22    - N-by-1 vector of squared wavenumbers (same flattening).
%   d1, d2 - diffusion coefficients of u and v.
%   beta   - reaction coupling coefficient.
%   t and dummy are not used by this function.
%   Returns the 2N-by-1 derivative: -d*K22 .* coefficients (diffusion) plus
%   the fft2 of the pointwise reaction terms.
idxU = 1:N;
idxV = (N+1):(2*N);

% Evaluate the nonlinear reaction terms in physical space.
u = real(ifft2(reshape((uvt(idxU)),n,n)));
v = real(ifft2(reshape((uvt(idxV)),n,n)));

uCube = u.^3;
vCube = v.^3;
uSqV  = (u.^2).*v;
uVSq  = u.*(v.^2);

% Reaction terms expanded in powers of u and v, transformed back to Fourier space.
reactU = reshape((fft2(u-uCube-uVSq+beta*uSqV+beta*vCube)),N,1);
reactV = reshape((fft2(v-uSqV-vCube-beta*uCube-beta*uVSq)),N,1);

% Diffusion acts diagonally in Fourier space.
rhsU = -d1*K22.*uvt(idxU)+reactU;
rhsV = -d2*K22.*uvt(N+1:end)+reactV;
rhs = [rhsU; rhsV];
================================================
FILE: datainfo/collect/single_pendulum/make_data.py
================================================
import os
import shutil
import numpy as np
from tqdm import tqdm
from PIL import Image
from scipy.integrate import solve_ivp
def mkdir(folder):
    """Create ``folder`` (with parents), wiping any previous contents first."""
    already_present = os.path.exists(folder)
    if already_present:
        shutil.rmtree(folder)
    os.makedirs(folder)
def engine(rng, num_frm, fps=60):
    """Simulate one single-pendulum trajectory sampled at ``fps`` frames per second.

    Integrates ``theta'' = -3*g/(2*l) * sin(theta)`` with ``l = 0.5`` and
    ``g = 9.81``. The initial angle is drawn uniformly from [0, 2*pi) and the
    initial angular velocity uniformly from [-10, 10), both taken from ``rng``.

    Args:
        rng: numpy ``Generator`` used for the random initial state.
        num_frm: number of samples to return.
        fps: sampling rate; consecutive samples are ``1 / fps`` apart.

    Returns:
        Array of shape ``(num_frm, 2)``. Column 0 is theta and column 1 is
        its time derivative.

    Raises:
        RuntimeError: if the ODE solver reports failure.
    """
    # parameters
    rod_length = 0.5
    g = 9.81
    dt = 1.0 / fps
    t_eval = np.arange(num_frm) * dt
    coef = -3 * g / (2 * rod_length)

    def f(t, y):
        # y = [theta, vel_theta]
        return [y[1], coef * np.sin(y[0])]

    initial_state = [rng.uniform(0, 2 * np.pi), rng.uniform(-10, 10)]
    sol = solve_ivp(f, [t_eval[0], t_eval[-1]], initial_state,
                    t_eval=t_eval, rtol=1e-6)
    if not sol.success:
        # Fail loudly here instead of returning a truncated trajectory.
        raise RuntimeError('pendulum integration failed: ' + str(sol.message))
    return sol.y.T
def render(theta):
    """Draw the pendulum rotated by ``theta`` radians as a 128x128 RGB image.

    The sprite './pendulum.png' (relative to the working directory) is pasted
    onto an 800x800 canvas with its top-left corner at (363, 400). The canvas
    is then rotated by ``theta`` converted to degrees, with uncovered corners
    filled with the background colour, and resized to 128x128.
    """
    bg_color = (215, 205, 192)
    im = Image.new('RGB', (800, 800), bg_color)
    # This runs once per rendered frame. Open the sprite in a context manager
    # so its file handle is released deterministically.
    with Image.open('./pendulum.png') as pendulum_im:
        im.paste(pendulum_im, (363, 400))
    im = im.rotate(theta*180/np.pi, fillcolor=bg_color)
    im = im.resize((128, 128))
    return im
def make_data(data_filepath, num_seq, num_frm, seed=0):
    """Simulate ``num_seq`` pendulum trajectories and render each frame to PNG.

    ``data_filepath`` is wiped and recreated. Frame ``k`` of sequence ``n`` is
    written to ``<data_filepath>/<n>/<k>.png``, and all simulated states are
    saved as ``<data_filepath>/states.npy`` with shape ``(num_seq, num_frm, 2)``.
    ``seed`` initialises the random generator for the initial conditions.
    """
    mkdir(data_filepath)
    rng = np.random.default_rng(seed)
    states = np.zeros((num_seq, num_frm, 2))
    for seq_idx in tqdm(range(num_seq)):
        seq_dir = os.path.join(data_filepath, str(seq_idx))
        mkdir(seq_dir)
        states[seq_idx] = engine(rng, num_frm)
        for frame_idx, theta in enumerate(states[seq_idx, :, 0]):
            frame = render(theta)
            frame.save(os.path.join(seq_dir, str(frame_idx) + '.png'))
    np.save(os.path.join(data_filepath, 'states.npy'), states)
if __name__ == '__main__':
    # NOTE(review): machine-specific absolute output path; adjust before running.
    data_filepath = '/data/kuang/visphysics_data/single_pendulum'
    make_data(data_filepath, num_frm=60, num_seq=1200)
================================================
FILE: datainfo/collect/utils/sort_vids.py
================================================
import os
import shutil
import numpy as np
from tqdm import tqdm
def mkdir(folder):
    """Replace whatever is at ``folder`` with a fresh empty directory."""
    if not os.path.exists(folder):
        os.makedirs(folder)
        return
    shutil.rmtree(folder)
    os.makedirs(folder)
data_folder = '/home/cml/Downloads/visphysics_data'
object_lst = ['air_dancer', 'fire', 'swingstick_magnetic', 'swingstick_non_magnetic']
# Number of equal-length videos each object's train / test frames are split into.
num_train_vids = [22, 1, 18, 68]
num_test_vids = [5, 1, 5, 17]
new_data_folder = '/home/cml/Downloads/visphysics_data_new'

if not len(object_lst) == len(num_train_vids) == len(num_test_vids):
    raise ValueError('object_lst, num_train_vids and num_test_vids must have equal length')


def _copy_videos(src_dir, dst_dir, num_vids, ext):
    """Split the flat frame folder ``src_dir`` into ``num_vids`` equal-length videos.

    Frames ``0<ext> ... <N-1><ext>`` (N = number of entries in ``src_dir``) are
    divided into ``num_vids`` consecutive chunks of equal size. Chunk ``j`` is
    copied to ``<dst_dir>/<j>/`` with its frames renumbered from 0.

    Raises:
        ValueError: raised by ``np.split`` when N is not divisible by ``num_vids``.
    """
    num_frames = len(os.listdir(src_dir))
    chunks = np.split(np.arange(num_frames), num_vids)
    for j, frame_ids in enumerate(tqdm(chunks)):
        new_vid_folder = os.path.join(dst_dir, str(j))
        mkdir(new_vid_folder)
        for k, p_idx in enumerate(frame_ids):
            src = os.path.join(src_dir, str(p_idx) + ext)
            dest = os.path.join(new_vid_folder, str(k) + ext)
            shutil.copy(src, dest)


for obj, n_train, n_test in tqdm(zip(object_lst, num_train_vids, num_test_vids),
                                 total=len(object_lst)):
    old_folder = os.path.join(data_folder, obj)
    new_folder = os.path.join(new_data_folder, obj)
    # Train frames are stored as .jpg, test frames as .png.
    _copy_videos(os.path.join(old_folder, 'raw_data_train'),
                 os.path.join(new_folder, 'train'), n_train, '.jpg')
    _copy_videos(os.path.join(old_folder, 'raw_data_test'),
                 os.path.join(new_folder, 'test'), n_test, '.png')
================================================
FILE: datainfo/double_pendulum/data_split_dict_1.json
================================================
{
"test": [
524,
1032,
960,
907,
977,
840,
877,
802,
392,
5,
746,
980,
623,
1068,
675,
1098,
931,
470,
361,
591,
867,
975,
352,
526,
414,
237,
561,
880,
942,
552,
1059,
789,
12,
1005,
464,
345,
348,
806,
631,
89,
961,
60,
1002,
758,
1046,
335,
221,
898,
177,
767,
751,
354,
848,
827,
497,
983,
70,
805,
1034,
1022,
581,
621,
388,
1039,
1072,
1025,
681,
247,
607,
380,
204,
852,
44,
593,
941,
448,
472,
707,
477,
1015,
896,
454,
59,
864,
443,
780,
18,
52,
45,
62,
650,
209,
468,
545,
912,
4,
886,
798,
58,
999,
192,
429,
777,
967,
920,
1014,
241,
522,
129,
275
],
"val": [
456,
164,
736,
36,
149,
406,
1075,
230,
21,
1017,
442,
620,
214,
1096,
747,
259,
111,
264,
1089,
815,
431,
351,
395,
319,
24,
116,
485,
508,
329,
719,
465,
301,
728,
663,
279,
672,
172,
540,
1063,
163,
171,
71,
297,
1011,
189,
639,
944,
112,
1004,
255,
287,
773,
772,
14,
463,
17,
888,
85,
72,
689,
861,
33,
261,
836,
871,
816,
564,
817,
93,
881,
185,
598,
563,
181,
1079,
235,
823,
28,
614,
469,
339,
627,
1033,
638,
553,
551,
1,
1040,
424,
365,
832,
496,
423,
516,
963,
1019,
567,
583,
373,
890,
492,
57,
972,
436,
210,
574,
796,
531,
132,
828
],
"train": [
391,
479,
51,
1077,
882,
625,
268,
197,
433,
635,
246,
685,
640,
803,
131,
398,
1070,
580,
924,
863,
962,
238,
39,
763,
653,
455,
608,
869,
680,
768,
634,
107,
514,
509,
884,
417,
219,
146,
311,
630,
996,
290,
80,
911,
366,
353,
130,
814,
1001,
765,
100,
565,
487,
498,
158,
147,
862,
887,
13,
529,
295,
444,
916,
795,
722,
260,
1016,
950,
609,
284,
895,
542,
535,
381,
450,
294,
750,
502,
494,
326,
556,
67,
1057,
384,
933,
143,
644,
362,
1084,
120,
16,
1094,
706,
788,
233,
905,
1036,
596,
733,
853,
418,
254,
119,
26,
559,
575,
549,
1088,
834,
966,
202,
434,
1095,
95,
308,
507,
590,
154,
988,
357,
1008,
425,
360,
234,
778,
819,
883,
90,
793,
611,
1028,
901,
812,
740,
432,
923,
68,
289,
167,
679,
784,
1093,
186,
671,
891,
849,
316,
729,
669,
213,
375,
801,
304,
991,
569,
9,
633,
603,
428,
190,
940,
642,
368,
771,
857,
282,
263,
419,
330,
1031,
420,
655,
876,
520,
1041,
544,
194,
791,
81,
122,
859,
3,
656,
1042,
201,
956,
753,
266,
262,
782,
243,
334,
54,
99,
341,
140,
280,
1027,
421,
787,
43,
401,
96,
115,
989,
1071,
973,
56,
734,
1061,
617,
850,
30,
928,
534,
573,
1067,
968,
949,
499,
343,
868,
990,
837,
242,
1092,
482,
1083,
992,
568,
350,
723,
969,
612,
467,
344,
810,
503,
1087,
693,
807,
619,
174,
717,
1076,
594,
29,
624,
586,
718,
336,
809,
926,
616,
441,
409,
1045,
1053,
954,
134,
451,
501,
537,
731,
533,
1010,
1006,
491,
708,
206,
688,
125,
927,
113,
770,
459,
709,
453,
390,
200,
713,
402,
906,
152,
538,
25,
379,
654,
766,
188,
267,
493,
661,
137,
776,
121,
764,
687,
430,
978,
570,
800,
47,
82,
701,
489,
959,
396,
762,
1020,
35,
216,
939,
50,
1007,
1026,
23,
948,
702,
438,
248,
987,
878,
77,
155,
278,
745,
135,
178,
998,
309,
37,
676,
759,
775,
870,
79,
1048,
539,
712,
699,
148,
412,
613,
986,
385,
34,
446,
484,
915,
207,
774,
979,
199,
657,
378,
269,
184,
371,
271,
856,
845,
394,
532,
359,
997,
673,
476,
739,
478,
236,
821,
515,
372,
665,
865,
695,
831,
965,
285,
879,
892,
958,
331,
1000,
666,
1090,
889,
374,
1091,
974,
276,
156,
179,
585,
643,
226,
935,
0,
714,
781,
790,
510,
1049,
257,
223,
447,
293,
651,
530,
825,
69,
332,
370,
1052,
548,
1065,
525,
6,
382,
291,
474,
726,
413,
55,
415,
452,
104,
196,
292,
756,
519,
1078,
215,
647,
716,
913,
779,
437,
1013,
582,
866,
748,
829,
571,
705,
299,
527,
139,
698,
277,
75,
606,
88,
1066,
838,
427,
124,
648,
321,
15,
725,
296,
486,
270,
97,
2,
286,
383,
193,
127,
808,
151,
338,
854,
483,
1080,
1050,
307,
677,
61,
108,
636,
700,
318,
684,
546,
166,
909,
84,
696,
1035,
462,
602,
481,
195,
760,
860,
123,
473,
435,
227,
517,
358,
900,
337,
495,
715,
315,
397,
785,
142,
936,
737,
622,
769,
660,
1074,
168,
393,
356,
403,
1018,
53,
833,
855,
659,
159,
180,
386,
641,
126,
87,
475,
169,
1058,
599,
844,
902,
208,
543,
918,
377,
523,
73,
595,
224,
281,
1024,
422,
1085,
490,
506,
724,
244,
231,
1064,
512,
953,
541,
952,
7,
994,
239,
662,
176,
410,
407,
730,
914,
572,
749,
141,
363,
63,
19,
841,
903,
1043,
1069,
830,
1038,
405,
11,
1023,
597,
449,
670,
678,
562,
440,
457,
683,
839,
253,
49,
161,
550,
946,
306,
182,
820,
566,
323,
32,
558,
145,
211,
1097,
300,
1012,
109,
312,
938,
144,
153,
183,
1037,
826,
743,
652,
513,
943,
157,
674,
505,
367,
921,
692,
22,
76,
847,
757,
930,
752,
804,
249,
20,
872,
250,
94,
792,
649,
1029,
874,
103,
342,
251,
310,
592,
682,
191,
799,
42,
668,
811,
232,
985,
658,
588,
92,
458,
91,
929,
738,
369,
252,
203,
314,
710,
554,
187,
265,
364,
480,
704,
632,
220,
256,
114,
466,
615,
324,
65,
64,
408,
320,
400,
460,
327,
610,
727,
10,
27,
908,
325,
667,
212,
102,
322,
488,
894,
741,
910,
555,
744,
445,
105,
842,
1056,
697,
919,
1009,
118,
165,
899,
600,
245,
897,
754,
843,
971,
947,
932,
686,
628,
1021,
735,
46,
110,
283,
1081,
893,
786,
577,
302,
993,
273,
83,
579,
229,
951,
584,
846,
824,
601,
629,
117,
349,
851,
150,
389,
74,
416,
40,
858,
813,
945,
1062,
797,
500,
732,
618,
783,
298,
904,
1054,
346,
376,
917,
995,
957,
340,
274,
794,
964,
170,
173,
136,
86,
41,
742,
66,
240,
1082,
970,
547,
703,
922,
560,
1051,
98,
818,
272,
218,
439,
347,
138,
576,
1047,
955,
160,
885,
288,
411,
626,
333,
934,
511,
981,
303,
399,
1055,
106,
504,
198,
605,
1073,
690,
587,
1044,
101,
355,
205,
387,
835,
521,
637,
720,
1086,
175,
471,
976,
222,
604,
38,
984,
8,
133,
258,
578,
426,
162,
761,
873,
317,
78,
937,
313,
48,
217,
128,
305,
755,
1030,
875,
646,
1003,
328,
822,
589,
691,
404,
31,
664,
536,
228,
461,
528,
711,
925,
645,
225,
1060,
557,
982,
694,
518,
721
]
}
================================================
FILE: datainfo/double_pendulum/data_split_dict_2.json
================================================
{
"test": [
24,
688,
255,
176,
368,
371,
58,
32,
777,
734,
919,
433,
61,
901,
966,
215,
844,
250,
272,
874,
139,
534,
772,
108,
937,
896,
698,
232,
605,
1066,
50,
668,
588,
60,
217,
391,
17,
699,
154,
750,
1001,
425,
638,
832,
1032,
621,
633,
982,
1082,
340,
664,
454,
996,
935,
718,
1062,
931,
1057,
1025,
1020,
571,
1003,
511,
944,
818,
330,
1060,
741,
724,
1079,
849,
912,
372,
1052,
736,
1065,
1044,
279,
355,
665,
361,
48,
472,
483,
363,
336,
867,
778,
652,
952,
745,
56,
1090,
549,
1028,
911,
761,
1042,
805,
882,
324,
73,
434,
515,
631,
346,
739,
173,
187,
115
],
"val": [
810,
639,
233,
1017,
82,
256,
75,
455,
731,
238,
252,
1041,
725,
520,
237,
650,
464,
98,
174,
165,
134,
34,
259,
995,
686,
143,
1049,
716,
18,
1048,
429,
620,
268,
265,
349,
924,
147,
335,
527,
498,
402,
799,
598,
530,
986,
895,
807,
458,
989,
104,
858,
323,
703,
676,
95,
230,
484,
157,
722,
636,
883,
411,
773,
270,
923,
46,
757,
619,
784,
564,
459,
315,
31,
500,
345,
292,
1098,
765,
760,
642,
630,
352,
4,
37,
155,
253,
813,
44,
603,
394,
1,
708,
535,
188,
752,
160,
958,
1033,
130,
261,
382,
21,
940,
746,
41,
25,
69,
977,
117,
84
],
"train": [
521,
97,
212,
436,
465,
835,
410,
195,
591,
288,
988,
828,
450,
127,
42,
93,
92,
970,
873,
239,
424,
1086,
956,
310,
266,
824,
142,
717,
119,
727,
1008,
357,
871,
584,
64,
396,
817,
406,
57,
2,
681,
9,
249,
864,
512,
198,
245,
40,
473,
1089,
299,
182,
601,
379,
321,
955,
763,
486,
766,
22,
906,
744,
171,
316,
702,
467,
370,
898,
274,
820,
1015,
644,
443,
707,
13,
576,
764,
109,
838,
559,
158,
26,
670,
721,
329,
789,
798,
802,
767,
133,
431,
539,
325,
672,
485,
618,
140,
135,
438,
692,
408,
608,
997,
640,
29,
344,
572,
421,
748,
6,
327,
1039,
235,
308,
796,
635,
290,
581,
713,
178,
612,
943,
568,
866,
207,
546,
519,
917,
889,
214,
339,
111,
448,
35,
63,
248,
282,
405,
163,
629,
782,
226,
659,
850,
902,
487,
1058,
592,
540,
596,
505,
831,
801,
575,
1046,
800,
833,
963,
556,
758,
504,
347,
236,
206,
441,
682,
1085,
547,
1040,
507,
87,
1094,
580,
975,
604,
648,
569,
294,
7,
916,
607,
788,
228,
281,
573,
376,
219,
915,
241,
167,
208,
1080,
216,
529,
1075,
94,
690,
907,
542,
578,
146,
120,
968,
49,
862,
211,
680,
447,
815,
586,
1068,
488,
693,
306,
271,
269,
314,
43,
792,
797,
976,
417,
954,
191,
499,
151,
729,
854,
848,
552,
781,
307,
407,
826,
881,
985,
1097,
1069,
397,
423,
508,
168,
76,
998,
627,
439,
860,
258,
229,
105,
645,
590,
751,
377,
759,
791,
597,
80,
106,
887,
891,
386,
753,
1054,
322,
100,
809,
965,
300,
843,
661,
913,
899,
617,
283,
855,
567,
99,
842,
926,
737,
574,
634,
967,
440,
987,
948,
637,
960,
886,
416,
562,
51,
513,
1019,
662,
732,
803,
66,
356,
460,
419,
897,
152,
932,
193,
359,
945,
606,
476,
583,
660,
624,
786,
369,
466,
506,
110,
557,
403,
1012,
351,
378,
1091,
981,
276,
415,
0,
251,
427,
953,
432,
942,
770,
683,
616,
566,
877,
787,
677,
456,
257,
70,
15,
286,
71,
929,
33,
123,
689,
756,
518,
332,
827,
47,
780,
614,
790,
12,
876,
112,
632,
541,
671,
129,
845,
694,
234,
156,
201,
482,
490,
884,
348,
1022,
1021,
302,
380,
921,
738,
231,
1084,
197,
210,
723,
1016,
354,
205,
964,
1038,
517,
755,
593,
1064,
305,
918,
278,
691,
1005,
1006,
242,
675,
200,
27,
116,
492,
430,
1073,
1030,
949,
687,
59,
295,
267,
920,
613,
1009,
398,
227,
880,
890,
90,
1081,
951,
785,
865,
331,
190,
172,
23,
312,
1061,
122,
362,
126,
646,
175,
45,
442,
678,
563,
446,
128,
169,
857,
775,
719,
1070,
991,
806,
623,
754,
77,
615,
684,
145,
463,
199,
304,
714,
666,
816,
863,
277,
868,
223,
333,
841,
878,
532,
808,
469,
39,
183,
85,
179,
53,
14,
65,
861,
704,
565,
1010,
457,
973,
244,
225,
647,
461,
132,
859,
247,
202,
936,
972,
962,
343,
544,
318,
1007,
149,
747,
959,
611,
385,
930,
479,
360,
445,
365,
384,
1092,
328,
287,
350,
409,
387,
743,
555,
1083,
1011,
728,
189,
776,
194,
480,
243,
101,
162,
334,
1018,
643,
649,
137,
847,
298,
452,
749,
170,
994,
554,
730,
79,
1043,
326,
218,
373,
388,
1037,
1045,
62,
284,
971,
705,
836,
888,
983,
622,
91,
516,
153,
543,
957,
837,
892,
933,
922,
673,
319,
341,
311,
651,
264,
783,
309,
551,
742,
136,
510,
879,
946,
453,
577,
654,
550,
900,
814,
969,
641,
570,
289,
10,
1095,
301,
653,
390,
1051,
166,
55,
138,
978,
908,
1034,
1067,
81,
28,
471,
514,
273,
470,
762,
496,
829,
67,
658,
795,
3,
358,
74,
209,
478,
177,
579,
525,
667,
493,
834,
8,
141,
626,
501,
779,
594,
600,
669,
204,
905,
812,
856,
220,
395,
910,
582,
221,
655,
1047,
823,
181,
822,
1026,
846,
320,
11,
894,
1024,
412,
192,
938,
545,
934,
83,
338,
560,
875,
1072,
710,
925,
1002,
414,
280,
494,
54,
522,
999,
184,
392,
609,
30,
451,
477,
148,
587,
367,
413,
240,
1071,
161,
1096,
263,
526,
254,
528,
1074,
720,
872,
553,
159,
114,
293,
536,
992,
449,
337,
711,
903,
774,
399,
735,
131,
38,
663,
697,
88,
68,
426,
150,
558,
180,
364,
72,
1014,
589,
599,
86,
715,
695,
712,
437,
561,
1055,
610,
1078,
974,
852,
769,
303,
696,
503,
909,
1063,
489,
870,
495,
275,
383,
144,
468,
366,
297,
509,
1013,
947,
1050,
481,
674,
1035,
1059,
1087,
709,
840,
400,
656,
342,
1093,
939,
125,
853,
260,
124,
404,
353,
771,
904,
794,
733,
401,
979,
585,
118,
196,
502,
941,
224,
78,
740,
804,
811,
927,
291,
5,
885,
628,
602,
213,
474,
497,
1077,
420,
1029,
462,
16,
990,
793,
203,
313,
851,
103,
422,
839,
1053,
950,
381,
706,
296,
375,
625,
121,
531,
19,
374,
491,
821,
679,
96,
185,
830,
537,
428,
52,
1088,
595,
980,
523,
435,
444,
825,
1036,
1004,
1076,
701,
1023,
389,
657,
548,
317,
914,
475,
685,
533,
984,
222,
107,
893,
869,
186,
20,
102,
928,
246,
89,
1056,
524,
113,
164,
418,
393,
36,
1027,
961,
768,
538,
285,
1000,
700,
262,
993,
726,
1031,
819
]
}
================================================
FILE: datainfo/double_pendulum/data_split_dict_3.json
================================================
{
"test": [
887,
659,
395,
532,
890,
471,
385,
386,
882,
918,
141,
981,
368,
321,
347,
888,
1003,
43,
706,
159,
269,
625,
298,
417,
994,
202,
971,
32,
548,
614,
110,
78,
7,
317,
1065,
483,
571,
677,
773,
92,
90,
243,
850,
1082,
601,
41,
1085,
840,
136,
704,
181,
990,
987,
129,
254,
583,
546,
432,
213,
668,
334,
572,
58,
689,
475,
834,
718,
790,
1038,
862,
1076,
893,
528,
444,
1013,
278,
73,
199,
748,
274,
910,
808,
874,
793,
968,
551,
63,
616,
87,
326,
131,
31,
798,
1071,
310,
474,
308,
813,
975,
963,
392,
479,
531,
960,
26,
134,
970,
757,
267,
487
],
"val": [
223,
897,
759,
344,
81,
173,
876,
829,
448,
229,
853,
691,
341,
329,
930,
104,
99,
714,
665,
445,
694,
191,
335,
244,
275,
823,
669,
227,
512,
1022,
1094,
752,
700,
582,
27,
832,
1056,
107,
996,
806,
307,
270,
988,
864,
936,
776,
320,
189,
372,
1039,
327,
615,
606,
305,
467,
643,
257,
378,
21,
692,
62,
974,
22,
501,
755,
285,
723,
623,
950,
939,
695,
361,
477,
340,
642,
648,
61,
1037,
647,
603,
630,
995,
20,
322,
593,
425,
807,
11,
1005,
561,
1083,
533,
264,
447,
1035,
958,
1030,
732,
737,
649,
441,
277,
519,
830,
1088,
635,
105,
1050,
697,
609
],
"train": [
401,
148,
464,
711,
957,
1084,
420,
1007,
318,
1074,
579,
972,
434,
2,
770,
220,
1010,
540,
192,
740,
857,
788,
545,
408,
352,
163,
469,
118,
758,
534,
506,
587,
504,
396,
863,
760,
231,
794,
1059,
1049,
768,
655,
456,
279,
313,
459,
769,
371,
80,
747,
66,
67,
525,
804,
521,
144,
1081,
894,
1060,
1066,
682,
439,
476,
142,
576,
620,
38,
672,
909,
345,
328,
1019,
193,
1087,
713,
179,
413,
156,
607,
640,
898,
1046,
511,
1027,
253,
558,
135,
781,
1098,
599,
945,
1073,
157,
905,
221,
197,
177,
59,
219,
154,
178,
486,
421,
727,
64,
53,
402,
1097,
355,
751,
228,
646,
631,
852,
452,
627,
1052,
309,
292,
462,
325,
784,
555,
83,
225,
149,
350,
1025,
1031,
468,
427,
969,
49,
846,
836,
262,
473,
1012,
520,
12,
573,
60,
212,
1048,
973,
820,
671,
374,
460,
249,
859,
171,
172,
198,
218,
407,
746,
908,
502,
112,
1036,
450,
594,
1089,
999,
931,
809,
301,
137,
236,
819,
639,
753,
843,
1002,
1069,
75,
207,
100,
214,
300,
346,
720,
8,
578,
89,
188,
272,
1070,
418,
84,
725,
167,
238,
875,
800,
470,
390,
1055,
835,
211,
312,
217,
656,
0,
405,
151,
941,
1026,
174,
375,
586,
304,
814,
315,
265,
855,
1072,
929,
143,
565,
778,
485,
779,
337,
201,
562,
399,
186,
709,
500,
608,
1080,
362,
29,
743,
621,
42,
48,
121,
949,
1040,
696,
266,
557,
685,
6,
690,
624,
10,
1044,
955,
952,
719,
424,
16,
680,
162,
415,
287,
575,
821,
411,
363,
780,
503,
754,
71,
14,
86,
935,
702,
394,
299,
717,
721,
921,
589,
977,
693,
658,
889,
916,
496,
114,
296,
185,
938,
842,
712,
792,
772,
19,
233,
454,
224,
1018,
896,
789,
966,
155,
698,
106,
216,
687,
872,
242,
449,
404,
82,
535,
480,
209,
998,
482,
629,
248,
113,
1041,
705,
1016,
673,
775,
40,
17,
370,
1034,
68,
841,
168,
165,
507,
9,
801,
762,
39,
259,
364,
316,
633,
803,
25,
319,
666,
239,
708,
914,
728,
145,
1057,
412,
815,
472,
518,
699,
108,
619,
397,
735,
564,
652,
684,
745,
251,
119,
338,
951,
943,
379,
97,
797,
387,
559,
1061,
466,
458,
351,
844,
1032,
324,
158,
679,
282,
161,
443,
208,
828,
15,
899,
605,
526,
103,
670,
653,
1045,
543,
451,
517,
597,
724,
976,
303,
65,
516,
1000,
574,
343,
1047,
290,
96,
357,
674,
416,
654,
667,
288,
617,
1058,
1095,
513,
437,
77,
57,
764,
498,
457,
736,
414,
822,
993,
206,
170,
928,
263,
393,
356,
681,
831,
628,
774,
937,
816,
913,
196,
138,
956,
750,
660,
200,
923,
294,
1064,
283,
510,
1028,
366,
707,
810,
947,
650,
595,
610,
246,
234,
95,
331,
153,
284,
854,
932,
638,
1001,
783,
169,
481,
72,
286,
55,
1075,
817,
901,
911,
455,
926,
91,
76,
741,
373,
489,
756,
777,
380,
505,
927,
967,
596,
367,
536,
180,
560,
252,
891,
1051,
1009,
585,
289,
538,
552,
946,
1024,
917,
644,
730,
591,
877,
377,
686,
895,
256,
989,
115,
442,
1020,
281,
924,
676,
786,
400,
742,
215,
984,
497,
388,
446,
339,
342,
30,
824,
997,
463,
365,
839,
851,
920,
664,
120,
866,
93,
881,
922,
925,
641,
884,
261,
845,
900,
140,
1015,
194,
867,
260,
865,
749,
715,
235,
478,
811,
767,
74,
802,
688,
570,
1090,
210,
491,
622,
147,
1067,
892,
645,
731,
176,
205,
391,
490,
1068,
23,
314,
222,
598,
791,
302,
550,
237,
703,
744,
150,
965,
273,
139,
190,
733,
453,
612,
611,
509,
98,
166,
152,
410,
837,
495,
529,
983,
878,
964,
164,
422,
240,
907,
661,
85,
701,
569,
358,
406,
613,
376,
785,
403,
959,
465,
255,
636,
247,
79,
184,
954,
332,
563,
782,
566,
102,
567,
109,
796,
541,
523,
663,
37,
1029,
306,
592,
291,
493,
117,
1093,
683,
787,
24,
183,
384,
101,
675,
94,
761,
45,
886,
241,
146,
1096,
5,
1033,
985,
544,
577,
1078,
111,
1092,
871,
795,
726,
860,
124,
268,
982,
383,
333,
28,
133,
436,
1004,
435,
738,
604,
276,
1017,
879,
849,
18,
716,
428,
430,
739,
508,
953,
381,
915,
979,
833,
44,
651,
722,
554,
880,
618,
1063,
382,
1006,
182,
580,
883,
389,
3,
568,
130,
126,
438,
336,
870,
271,
369,
311,
600,
398,
662,
991,
359,
1023,
869,
160,
323,
942,
514,
527,
88,
729,
827,
494,
70,
766,
127,
47,
1042,
56,
1,
330,
484,
52,
433,
992,
539,
632,
903,
1043,
848,
51,
409,
584,
934,
499,
1021,
1079,
280,
245,
799,
1086,
547,
125,
1014,
226,
360,
948,
1054,
553,
258,
128,
818,
116,
626,
54,
838,
549,
537,
961,
904,
902,
919,
515,
175,
873,
861,
492,
13,
50,
590,
440,
203,
524,
710,
35,
46,
250,
122,
293,
602,
69,
232,
349,
556,
295,
1091,
765,
678,
771,
933,
763,
33,
906,
734,
986,
1062,
522,
637,
488,
4,
204,
1008,
423,
36,
419,
581,
429,
297,
426,
978,
354,
1053,
885,
812,
530,
1011,
431,
132,
940,
353,
634,
825,
1077,
657,
847,
868,
348,
944,
187,
588,
858,
856,
826,
962,
195,
542,
34,
123,
805,
230,
980,
461,
912
]
}
================================================
FILE: datainfo/elastic_pendulum/data_split_dict_1.json
================================================
{
"test": [
187,
370,
362,
470,
57,
938,
678,
3,
708,
1137,
730,
993,
846,
1033,
409,
746,
985,
114,
872,
420,
1062,
264,
1049,
785,
11,
551,
940,
723,
704,
1052,
828,
475,
1105,
408,
25,
464,
1028,
345,
348,
806,
631,
89,
961,
60,
1002,
758,
1142,
1066,
335,
221,
1041,
898,
177,
767,
1123,
751,
354,
848,
827,
497,
983,
70,
805,
1034,
1022,
581,
621,
388,
1039,
1171,
1025,
681,
247,
607,
380,
204,
1139,
852,
44,
593,
941,
448,
472,
707,
477,
1132,
1015,
896,
454,
1080,
59,
864,
443,
780,
18,
1108,
52,
45,
62,
650,
209,
468,
545,
912,
4,
886,
798,
58,
999,
192,
429,
777,
967,
920,
1014,
241,
522,
129,
1165,
275
],
"val": [
411,
892,
626,
333,
17,
1071,
516,
303,
399,
1151,
960,
106,
504,
198,
605,
1172,
918,
690,
587,
210,
101,
355,
205,
387,
1000,
521,
637,
720,
1186,
890,
888,
847,
175,
471,
583,
922,
1096,
1046,
839,
604,
38,
870,
899,
574,
8,
133,
258,
578,
426,
162,
761,
305,
1122,
939,
317,
78,
879,
1036,
313,
1053,
1167,
217,
991,
257,
611,
120,
1093,
657,
808,
1178,
457,
923,
451,
873,
1183,
328,
72,
299,
813,
36,
461,
42,
884,
428,
1044,
519,
222,
529,
385,
862,
703,
791,
638,
48,
233,
970,
1016,
659,
931,
603,
558,
344,
1079,
326,
342,
142,
594,
705,
378,
224,
550,
511,
575,
29,
927,
34,
170,
144,
66,
1196
],
"train": [
483,
489,
341,
685,
96,
1,
1148,
542,
656,
744,
1085,
613,
855,
775,
680,
111,
1081,
648,
308,
733,
1023,
189,
759,
61,
482,
639,
1190,
228,
776,
596,
885,
1024,
1187,
1048,
295,
1158,
1146,
1094,
819,
684,
540,
793,
856,
1099,
622,
262,
773,
609,
107,
263,
260,
1160,
958,
474,
672,
739,
442,
478,
887,
1083,
734,
612,
1084,
956,
843,
55,
1067,
750,
823,
9,
532,
112,
282,
329,
520,
874,
359,
1125,
1102,
514,
668,
1164,
459,
197,
844,
753,
635,
633,
185,
1195,
966,
201,
906,
1059,
569,
146,
580,
39,
1047,
446,
597,
276,
43,
1042,
533,
561,
166,
1003,
154,
778,
647,
243,
67,
768,
617,
1149,
634,
981,
649,
450,
507,
215,
1057,
740,
691,
710,
1120,
858,
889,
682,
71,
531,
869,
415,
849,
444,
237,
716,
214,
764,
735,
765,
595,
486,
90,
834,
989,
437,
292,
1127,
549,
330,
16,
304,
559,
1168,
104,
186,
916,
719,
99,
623,
476,
293,
498,
286,
1136,
905,
503,
143,
952,
917,
167,
242,
26,
797,
149,
315,
652,
897,
871,
234,
271,
223,
1029,
769,
238,
673,
924,
294,
1199,
567,
33,
207,
821,
945,
795,
895,
548,
784,
729,
987,
412,
908,
986,
965,
375,
427,
6,
891,
456,
195,
980,
641,
435,
419,
1004,
925,
721,
538,
80,
718,
762,
679,
1063,
131,
937,
311,
307,
552,
485,
1054,
1043,
959,
1176,
840,
1037,
432,
699,
1177,
812,
424,
1157,
752,
360,
190,
68,
316,
693,
1194,
383,
384,
951,
517,
741,
1110,
1072,
671,
676,
219,
179,
1121,
30,
396,
1106,
1005,
585,
402,
321,
954,
943,
606,
749,
255,
859,
530,
865,
390,
395,
976,
877,
556,
268,
1007,
755,
957,
909,
702,
933,
509,
929,
130,
582,
289,
1073,
277,
756,
973,
1100,
394,
525,
81,
598,
100,
147,
379,
196,
644,
851,
2,
782,
351,
824,
1090,
56,
995,
586,
636,
919,
660,
108,
199,
1045,
280,
1091,
1107,
553,
757,
421,
193,
655,
653,
119,
494,
492,
1019,
675,
343,
1134,
140,
339,
692,
453,
1018,
438,
1163,
964,
571,
695,
640,
28,
696,
95,
51,
510,
164,
493,
910,
54,
1188,
156,
911,
998,
1170,
123,
838,
158,
184,
760,
401,
371,
1141,
974,
269,
235,
134,
770,
75,
372,
747,
227,
296,
1032,
374,
1173,
487,
97,
434,
127,
893,
1129,
841,
206,
591,
15,
338,
125,
361,
619,
694,
113,
366,
599,
413,
469,
404,
842,
88,
1155,
787,
200,
1092,
1068,
738,
714,
463,
630,
543,
913,
334,
152,
481,
1114,
645,
285,
462,
854,
188,
337,
267,
425,
357,
984,
717,
137,
1031,
121,
947,
270,
1131,
473,
484,
1086,
1013,
836,
570,
47,
353,
82,
1088,
832,
84,
501,
226,
336,
1075,
666,
35,
216,
781,
508,
50,
465,
491,
23,
763,
669,
926,
674,
132,
687,
236,
77,
155,
278,
350,
1058,
715,
1082,
534,
135,
467,
178,
5,
309,
397,
417,
37,
754,
670,
433,
902,
815,
524,
479,
544,
664,
79,
701,
789,
528,
748,
829,
148,
332,
495,
904,
539,
381,
745,
790,
69,
1061,
414,
624,
398,
727,
368,
202,
537,
174,
535,
590,
1162,
1051,
643,
978,
1087,
642,
602,
915,
115,
972,
779,
1038,
731,
358,
816,
452,
850,
0,
331,
1145,
833,
515,
447,
568,
382,
663,
1135,
139,
1192,
13,
589,
689,
867,
290,
1180,
392,
213,
124,
430,
502,
365,
722,
172,
935,
651,
318,
279,
955,
950,
151,
291,
736,
724,
266,
248,
31,
21,
883,
665,
194,
1184,
572,
837,
254,
284,
820,
1150,
418,
546,
441,
122,
1021,
499,
725,
662,
963,
809,
406,
712,
391,
246,
455,
1124,
1065,
85,
875,
928,
914,
783,
536,
1074,
766,
168,
527,
393,
356,
403,
1027,
53,
994,
997,
436,
700,
159,
180,
386,
860,
831,
620,
126,
87,
608,
1111,
169,
1154,
565,
713,
1011,
319,
208,
646,
163,
377,
523,
73,
709,
1069,
281,
625,
573,
1117,
422,
230,
490,
506,
814,
244,
231,
654,
1161,
512,
772,
541,
181,
7,
944,
239,
932,
627,
176,
410,
1098,
407,
868,
802,
737,
907,
1153,
141,
901,
363,
63,
19,
880,
1008,
661,
24,
1156,
1101,
992,
1143,
405,
1104,
1115,
711,
449,
794,
1119,
562,
440,
1030,
698,
811,
1006,
253,
683,
49,
161,
1070,
975,
306,
182,
979,
706,
566,
688,
1109,
323,
32,
1060,
145,
211,
1197,
300,
616,
526,
726,
109,
312,
817,
1077,
153,
183,
1198,
969,
996,
881,
774,
513,
1133,
157,
799,
505,
367,
297,
822,
1179,
22,
76,
1095,
1017,
900,
1193,
894,
1055,
249,
20,
225,
250,
94,
818,
1103,
962,
557,
103,
1064,
251,
310,
592,
810,
191,
953,
1130,
792,
968,
232,
948,
658,
588,
826,
92,
458,
771,
91,
287,
876,
369,
252,
203,
314,
1138,
554,
1169,
265,
364,
677,
480,
1175,
835,
796,
632,
803,
220,
256,
1097,
466,
615,
324,
65,
64,
1113,
320,
400,
460,
327,
610,
743,
863,
866,
10,
27,
853,
325,
667,
212,
102,
322,
488,
728,
259,
1056,
301,
555,
825,
882,
445,
105,
1010,
1009,
1152,
697,
171,
1040,
118,
165,
431,
600,
804,
245,
1189,
1020,
1116,
845,
1001,
423,
93,
14,
686,
628,
12,
1026,
1144,
46,
1126,
110,
283,
1181,
1012,
934,
577,
302,
373,
273,
83,
579,
229,
563,
584,
1166,
1140,
800,
601,
629,
117,
349,
128,
1089,
150,
807,
1185,
389,
74,
416,
40,
1035,
971,
788,
564,
1159,
949,
500,
732,
988,
618,
990,
930,
298,
116,
1118,
346,
376,
261,
861,
518,
614,
340,
1191,
274,
946,
1112,
1076,
173,
136,
86,
41,
742,
1078,
240,
1182,
786,
496,
547,
1050,
942,
903,
936,
352,
560,
1147,
857,
98,
977,
272,
218,
439,
347,
138,
801,
576,
830,
1128,
878,
982,
160,
1174,
288,
921
]
}
================================================
FILE: datainfo/elastic_pendulum/data_split_dict_2.json
================================================
{
"test": [
789,
3,
1071,
376,
321,
261,
523,
764,
43,
83,
51,
138,
235,
169,
1167,
510,
352,
737,
742,
116,
65,
866,
123,
431,
501,
544,
1163,
1069,
1113,
464,
559,
100,
120,
217,
391,
17,
699,
154,
750,
1048,
1001,
425,
638,
832,
1039,
1060,
1032,
621,
633,
982,
1181,
340,
664,
454,
996,
935,
718,
1171,
931,
1152,
1055,
1025,
1020,
571,
1003,
511,
1086,
944,
818,
330,
1156,
741,
724,
1178,
1075,
849,
912,
372,
1146,
1052,
736,
1162,
1044,
279,
355,
665,
361,
48,
472,
483,
363,
1147,
336,
1076,
867,
778,
652,
952,
745,
56,
1191,
549,
1028,
911,
1114,
761,
1042,
805,
882,
324,
1190,
73,
434,
515,
631,
346,
739,
173,
187,
115
],
"val": [
954,
706,
296,
375,
625,
121,
943,
531,
19,
374,
491,
821,
679,
96,
942,
185,
983,
537,
951,
428,
902,
52,
605,
595,
21,
1158,
435,
444,
825,
746,
215,
986,
1175,
1037,
701,
1108,
389,
657,
548,
317,
1109,
475,
685,
533,
25,
222,
107,
237,
768,
186,
20,
102,
104,
246,
89,
1151,
524,
113,
1029,
418,
393,
1046,
1117,
9,
570,
1101,
525,
1160,
467,
164,
513,
150,
910,
476,
505,
64,
474,
928,
196,
349,
1149,
269,
68,
518,
1099,
287,
36,
859,
536,
530,
698,
294,
671,
997,
804,
1085,
917,
49,
208,
647,
191,
461,
968,
314,
822,
540,
93,
918,
1194,
63,
1000,
690,
585,
231,
704,
8,
74,
310,
507,
88
],
"train": [
689,
907,
403,
285,
1073,
796,
316,
140,
1183,
132,
861,
1047,
167,
556,
985,
999,
1185,
891,
984,
199,
717,
40,
937,
354,
405,
271,
1155,
1125,
262,
900,
292,
304,
151,
909,
397,
580,
820,
646,
667,
1057,
183,
684,
760,
450,
1159,
726,
198,
747,
880,
1012,
335,
719,
39,
913,
84,
168,
207,
1051,
396,
618,
1088,
834,
567,
146,
923,
814,
921,
629,
299,
1139,
45,
734,
216,
758,
302,
424,
128,
60,
406,
85,
447,
205,
710,
547,
566,
32,
156,
18,
892,
105,
1170,
239,
659,
1061,
797,
76,
248,
541,
641,
783,
969,
1199,
508,
94,
58,
394,
639,
863,
479,
819,
1034,
487,
947,
862,
283,
305,
327,
438,
744,
975,
386,
723,
202,
756,
979,
236,
1092,
1094,
774,
343,
219,
920,
360,
945,
601,
810,
998,
568,
932,
197,
603,
830,
527,
808,
940,
46,
792,
875,
816,
1005,
506,
916,
634,
976,
244,
1009,
2,
1198,
290,
766,
1182,
840,
417,
182,
249,
210,
443,
623,
307,
590,
1043,
592,
925,
155,
858,
675,
90,
427,
442,
848,
846,
1093,
692,
1030,
1131,
35,
255,
111,
1166,
469,
1016,
22,
614,
377,
1018,
26,
195,
499,
165,
37,
333,
883,
899,
247,
906,
981,
288,
365,
616,
693,
1070,
7,
1135,
441,
1179,
874,
188,
992,
851,
347,
1054,
456,
258,
991,
529,
1110,
1066,
970,
69,
370,
457,
735,
1137,
546,
1050,
13,
879,
312,
348,
872,
23,
226,
611,
1130,
678,
445,
57,
557,
569,
853,
668,
532,
272,
806,
379,
977,
974,
223,
266,
385,
127,
971,
31,
562,
1112,
1058,
624,
753,
798,
1161,
300,
281,
980,
552,
253,
42,
142,
228,
486,
573,
436,
1103,
407,
459,
781,
1023,
1142,
780,
1077,
225,
175,
517,
714,
163,
29,
1157,
622,
14,
6,
961,
473,
274,
660,
157,
329,
295,
1017,
241,
617,
1065,
1021,
158,
731,
828,
811,
1064,
227,
378,
888,
565,
1063,
586,
1116,
411,
1098,
707,
232,
229,
87,
1100,
423,
833,
1041,
838,
1134,
612,
133,
896,
214,
109,
578,
946,
887,
645,
398,
1121,
539,
757,
670,
212,
211,
687,
572,
865,
1196,
700,
80,
584,
149,
588,
662,
351,
1136,
972,
542,
708,
325,
521,
455,
1111,
1096,
957,
432,
306,
786,
655,
1033,
458,
591,
850,
1193,
134,
380,
1174,
1148,
1144,
1040,
1010,
962,
200,
99,
683,
876,
384,
282,
308,
482,
1192,
1123,
857,
357,
894,
583,
669,
117,
615,
250,
785,
466,
564,
139,
575,
1090,
607,
50,
415,
1053,
1097,
868,
950,
66,
498,
1180,
1118,
174,
152,
953,
193,
967,
898,
1084,
77,
1019,
847,
53,
705,
776,
345,
649,
1056,
268,
581,
201,
490,
331,
318,
635,
619,
672,
519,
598,
27,
632,
108,
492,
534,
126,
59,
276,
812,
0,
251,
855,
484,
122,
448,
270,
1124,
596,
267,
416,
941,
835,
949,
964,
749,
930,
800,
277,
864,
356,
70,
15,
286,
71,
1013,
33,
1102,
1127,
460,
439,
897,
613,
332,
419,
172,
784,
504,
703,
933,
12,
844,
112,
606,
772,
948,
129,
4,
593,
408,
234,
993,
666,
637,
179,
512,
965,
839,
966,
110,
676,
463,
751,
421,
815,
465,
410,
1078,
794,
92,
1189,
1015,
485,
1067,
955,
233,
597,
1195,
119,
1059,
576,
135,
278,
1133,
440,
1035,
759,
147,
1107,
929,
344,
47,
245,
252,
914,
608,
654,
256,
369,
430,
885,
257,
339,
769,
145,
97,
680,
206,
323,
807,
1022,
752,
1138,
554,
787,
242,
604,
555,
738,
446,
721,
809,
171,
359,
106,
711,
130,
642,
322,
190,
1014,
488,
802,
362,
673,
178,
677,
904,
653,
1083,
788,
1105,
630,
799,
328,
775,
713,
1045,
350,
409,
387,
563,
1188,
651,
702,
915,
1122,
189,
919,
194,
480,
243,
101,
162,
334,
1106,
755,
905,
137,
813,
298,
452,
889,
170,
371,
767,
716,
79,
852,
326,
218,
373,
388,
644,
827,
62,
284,
535,
836,
574,
886,
238,
41,
681,
91,
643,
516,
153,
543,
1141,
901,
990,
648,
520,
95,
1049,
893,
658,
319,
661,
341,
311,
765,
264,
926,
309,
551,
881,
136,
1132,
1095,
958,
762,
453,
577,
1008,
550,
34,
963,
1,
729,
1024,
289,
10,
826,
301,
1074,
1087,
390,
1145,
166,
55,
1091,
640,
691,
1140,
790,
779,
1164,
934,
81,
28,
471,
514,
273,
470,
903,
1082,
496,
1129,
67,
777,
1002,
824,
939,
1081,
358,
1173,
209,
478,
177,
579,
1026,
1080,
493,
987,
694,
627,
1154,
141,
626,
1104,
922,
594,
600,
791,
204,
1143,
873,
1011,
220,
395,
620,
582,
1187,
221,
770,
1115,
973,
181,
1165,
803,
728,
842,
1120,
320,
11,
650,
1184,
682,
412,
192,
636,
545,
230,
743,
1089,
956,
763,
730,
338,
560,
368,
1169,
841,
890,
960,
748,
727,
259,
414,
280,
494,
54,
522,
402,
184,
754,
392,
609,
30,
878,
451,
477,
148,
989,
587,
782,
801,
367,
413,
240,
720,
1168,
161,
1197,
263,
526,
732,
254,
528,
1172,
854,
433,
924,
553,
159,
114,
293,
1119,
176,
449,
337,
843,
686,
725,
869,
399,
871,
131,
38,
663,
829,
697,
1079,
884,
1186,
426,
988,
1031,
558,
180,
364,
72,
98,
589,
837,
877,
599,
86,
715,
695,
712,
437,
561,
265,
610,
500,
160,
1007,
1150,
303,
696,
856,
503,
429,
823,
1153,
1027,
489,
538,
495,
275,
383,
817,
144,
468,
366,
297,
509,
722,
927,
44,
795,
481,
674,
1128,
1004,
1177,
709,
995,
400,
656,
342,
870,
1068,
82,
125,
936,
260,
124,
908,
404,
353,
771,
143,
938,
733,
401,
382,
1072,
118,
1038,
502,
773,
224,
78,
740,
978,
959,
24,
291,
5,
75,
628,
602,
1126,
213,
1036,
497,
1176,
420,
61,
462,
831,
16,
845,
688,
793,
860,
203,
313,
1006,
103,
422,
994,
895,
1062,
315,
381
]
}
================================================
FILE: datainfo/elastic_pendulum/data_split_dict_3.json
================================================
{
"test": [
1194,
644,
1184,
23,
694,
620,
1067,
1157,
895,
1155,
486,
883,
555,
1153,
210,
1152,
1065,
942,
771,
1121,
283,
737,
642,
695,
86,
319,
539,
597,
835,
404,
64,
1096,
221,
157,
14,
634,
1161,
483,
1035,
571,
677,
773,
92,
90,
243,
850,
1167,
601,
41,
1181,
840,
136,
704,
181,
990,
987,
129,
254,
583,
546,
432,
213,
1109,
668,
334,
572,
58,
689,
475,
834,
1093,
718,
790,
1038,
862,
1172,
893,
528,
444,
1013,
278,
73,
199,
748,
274,
910,
808,
874,
793,
968,
551,
63,
616,
87,
326,
131,
31,
798,
1071,
310,
474,
308,
813,
975,
1125,
1107,
963,
392,
479,
1128,
531,
960,
26,
134,
1189,
970,
757,
267,
1114,
487
],
"val": [
710,
822,
925,
35,
46,
250,
829,
122,
293,
602,
878,
1029,
882,
232,
911,
349,
556,
295,
1190,
765,
678,
1079,
856,
467,
763,
33,
227,
734,
949,
972,
1145,
1158,
522,
637,
901,
852,
965,
1047,
4,
204,
159,
423,
1097,
36,
419,
581,
429,
297,
426,
649,
354,
1148,
277,
870,
812,
530,
298,
431,
132,
603,
353,
1051,
825,
175,
696,
570,
375,
645,
390,
69,
247,
460,
554,
923,
446,
1147,
163,
346,
897,
459,
683,
659,
208,
198,
891,
383,
671,
488,
1170,
455,
1024,
1115,
269,
55,
214,
772,
615,
540,
1086,
640,
379,
745,
363,
655,
611,
934,
514,
756,
43,
124,
45,
1002,
1119,
1074,
722,
954,
680,
123,
272,
1098
],
"train": [
464,
1062,
890,
22,
660,
427,
290,
167,
469,
1159,
341,
712,
560,
457,
77,
952,
639,
462,
1185,
1195,
654,
135,
83,
797,
730,
372,
650,
218,
225,
787,
285,
265,
343,
279,
452,
1028,
397,
951,
948,
421,
855,
625,
158,
846,
1004,
1009,
753,
72,
466,
559,
914,
80,
591,
738,
142,
1120,
847,
408,
708,
973,
552,
236,
364,
1186,
207,
606,
1103,
338,
741,
622,
282,
178,
991,
263,
1054,
1112,
612,
586,
1072,
1193,
371,
1129,
480,
66,
1176,
781,
691,
507,
735,
853,
91,
575,
386,
192,
928,
67,
335,
785,
656,
585,
220,
789,
374,
347,
706,
231,
196,
188,
752,
681,
348,
324,
1177,
309,
151,
727,
929,
525,
688,
100,
95,
1141,
1124,
352,
519,
1090,
624,
1075,
623,
266,
351,
219,
723,
1133,
470,
1085,
75,
38,
1150,
510,
1031,
998,
800,
8,
641,
676,
370,
104,
288,
766,
415,
286,
292,
521,
411,
42,
1182,
443,
1110,
1092,
811,
873,
1073,
731,
842,
977,
350,
561,
788,
424,
1179,
312,
876,
906,
141,
321,
228,
564,
294,
791,
414,
922,
76,
118,
328,
138,
1166,
1156,
356,
628,
714,
841,
777,
1142,
799,
803,
103,
875,
844,
638,
1063,
1055,
726,
953,
1113,
313,
518,
315,
449,
935,
1160,
65,
881,
482,
6,
633,
784,
296,
498,
10,
1081,
337,
331,
865,
867,
434,
1169,
380,
605,
394,
154,
889,
884,
393,
412,
143,
217,
500,
587,
783,
815,
1052,
945,
535,
262,
961,
1007,
635,
12,
284,
61,
1089,
767,
672,
1076,
1197,
607,
717,
1117,
327,
999,
866,
1149,
0,
399,
1010,
345,
230,
994,
667,
121,
1036,
49,
1144,
1199,
1101,
992,
489,
451,
249,
533,
979,
927,
137,
666,
595,
496,
770,
959,
212,
534,
1099,
978,
1100,
197,
89,
697,
686,
171,
913,
1017,
477,
1146,
211,
558,
827,
1108,
201,
156,
619,
1039,
179,
915,
1088,
679,
437,
169,
657,
96,
27,
246,
502,
627,
684,
473,
647,
1005,
53,
739,
1046,
303,
485,
851,
29,
955,
373,
576,
699,
172,
15,
1131,
287,
930,
725,
589,
238,
653,
148,
454,
304,
1171,
511,
59,
318,
170,
84,
97,
857,
1162,
48,
947,
112,
481,
252,
703,
941,
1026,
617,
981,
837,
1021,
16,
1042,
162,
664,
838,
760,
517,
1059,
251,
1033,
420,
1154,
886,
71,
1198,
325,
1104,
257,
1015,
299,
830,
1084,
614,
848,
506,
721,
966,
849,
715,
596,
1018,
107,
888,
759,
418,
305,
669,
504,
1014,
608,
1027,
536,
368,
19,
233,
1122,
224,
796,
1020,
447,
234,
698,
11,
545,
956,
750,
106,
216,
907,
34,
1135,
242,
899,
200,
82,
818,
450,
1140,
476,
209,
21,
402,
1053,
161,
248,
113,
387,
357,
819,
543,
833,
32,
1095,
1048,
916,
1087,
40,
366,
17,
396,
578,
1006,
439,
1080,
68,
983,
355,
682,
445,
168,
165,
648,
9,
898,
743,
39,
259,
119,
316,
1061,
950,
25,
472,
574,
782,
405,
513,
971,
1056,
456,
858,
674,
145,
289,
57,
986,
786,
1044,
1180,
709,
416,
821,
108,
1019,
940,
693,
828,
744,
646,
805,
177,
503,
239,
905,
2,
859,
526,
824,
764,
920,
317,
871,
562,
417,
1049,
631,
300,
378,
610,
206,
588,
174,
1043,
407,
149,
912,
621,
1008,
193,
189,
1139,
673,
565,
817,
155,
114,
816,
361,
806,
458,
413,
340,
362,
700,
253,
794,
516,
685,
845,
705,
832,
1187,
900,
1078,
754,
367,
401,
976,
755,
185,
969,
1041,
180,
1105,
768,
468,
652,
944,
854,
974,
629,
144,
573,
110,
1130,
742,
1060,
775,
153,
186,
724,
301,
60,
924,
520,
505,
936,
538,
939,
579,
1037,
1058,
860,
1050,
461,
377,
807,
1064,
99,
256,
1057,
448,
115,
442,
78,
281,
917,
885,
557,
932,
400,
872,
701,
215,
747,
497,
388,
1034,
339,
342,
30,
919,
344,
670,
463,
365,
1003,
1016,
270,
780,
120,
1173,
93,
894,
187,
320,
663,
173,
261,
762,
636,
191,
140,
264,
194,
746,
260,
769,
880,
728,
235,
478,
964,
903,
74,
988,
809,
1025,
1143,
594,
1188,
1094,
736,
491,
758,
147,
879,
329,
1196,
861,
1138,
176,
205,
391,
490,
1164,
1083,
314,
222,
1132,
692,
598,
609,
938,
687,
302,
550,
237,
826,
887,
599,
150,
957,
273,
139,
190,
863,
453,
720,
1151,
509,
98,
166,
152,
410,
658,
1000,
495,
702,
529,
690,
593,
582,
802,
425,
164,
422,
240,
512,
776,
85,
823,
569,
358,
406,
613,
376,
931,
403,
716,
630,
465,
255,
661,
1163,
1030,
79,
184,
761,
332,
563,
962,
566,
102,
567,
109,
943,
732,
541,
523,
779,
37,
1123,
306,
592,
291,
713,
493,
1118,
117,
1192,
804,
933,
24,
183,
384,
675,
101,
792,
94,
896,
1070,
864,
241,
146,
892,
5,
995,
1127,
105,
544,
577,
1174,
751,
111,
1040,
385,
542,
1178,
749,
982,
707,
1011,
1069,
268,
1134,
1045,
333,
28,
133,
436,
229,
435,
868,
604,
711,
276,
548,
223,
996,
18,
909,
428,
1191,
430,
869,
719,
778,
508,
1102,
381,
937,
441,
993,
44,
820,
651,
1082,
1032,
926,
665,
618,
471,
382,
1068,
182,
580,
81,
814,
389,
740,
733,
3,
568,
130,
126,
438,
336,
195,
271,
369,
311,
600,
398,
662,
395,
359,
1116,
322,
160,
323,
501,
1066,
527,
88,
729,
1126,
985,
494,
70,
902,
127,
47,
1136,
1168,
56,
877,
1,
839,
795,
330,
484,
52,
433,
532,
1106,
632,
275,
1137,
1012,
774,
51,
409,
1091,
584,
643,
499,
7,
843,
1175,
1022,
1165,
280,
810,
245,
958,
967,
836,
908,
547,
125,
202,
226,
360,
62,
801,
831,
997,
553,
989,
258,
128,
1183,
116,
626,
1111,
54,
1001,
549,
537,
20,
980,
244,
307,
515,
1023,
1077,
984,
492,
13,
50,
590,
440,
921,
904,
918,
203,
946,
524
]
}
================================================
FILE: datainfo/fire/data_split_dict_1.json
================================================
{
"test": [
248,
491,
35,
873,
603,
402,
517,
866,
511,
903,
601,
290,
310,
194,
686,
849,
519,
951,
512,
728,
968,
917,
340,
760,
123,
303,
880,
741,
644,
190,
102,
657,
569,
857,
426,
960,
296,
470,
989,
224,
693,
236,
353,
238,
566,
990,
448,
994,
227,
540,
980,
743,
432,
221,
702,
390,
902,
9,
554,
665,
26,
22,
31,
325,
923,
104,
605,
234,
995,
738,
272,
456,
712,
2,
785,
780,
622,
443,
399,
855,
914,
29,
499,
96,
214,
807,
388,
667,
483,
460,
779,
507,
120,
261,
64,
782,
821,
867,
582,
137
],
"val": [
992,
564,
93,
185,
598,
563,
181,
650,
235,
28,
614,
469,
339,
627,
805,
638,
553,
551,
1,
354,
895,
365,
496,
423,
516,
856,
567,
583,
373,
492,
57,
436,
210,
574,
796,
531,
132,
828,
524,
758,
802,
392,
5,
746,
623,
854,
675,
275,
936,
361,
591,
352,
526,
414,
237,
891,
552,
204,
789,
12,
232,
514,
172,
174,
662,
403,
592,
607,
629,
720,
315,
44,
480,
30,
750,
501,
379,
904,
860,
533,
167,
797,
110,
520,
679,
449,
88,
383,
755,
690,
794,
719,
561,
375,
177,
680,
424,
413,
816,
761
],
"train": [
852,
836,
280,
575,
208,
609,
931,
725,
47,
984,
337,
581,
155,
692,
862,
790,
572,
241,
991,
788,
913,
498,
107,
767,
839,
930,
300,
386,
490,
920,
49,
216,
886,
73,
125,
565,
262,
717,
596,
669,
309,
803,
876,
848,
580,
864,
447,
484,
418,
147,
556,
323,
37,
586,
438,
409,
52,
599,
687,
878,
716,
243,
131,
562,
219,
95,
808,
82,
696,
306,
901,
144,
998,
415,
548,
4,
704,
707,
143,
269,
879,
479,
284,
888,
950,
294,
949,
13,
705,
515,
751,
356,
121,
621,
872,
894,
982,
472,
477,
863,
954,
295,
729,
378,
641,
753,
58,
798,
371,
271,
973,
169,
239,
207,
358,
955,
600,
509,
434,
585,
34,
634,
433,
311,
701,
475,
331,
473,
615,
7,
421,
266,
56,
384,
183,
391,
870,
200,
100,
59,
494,
893,
146,
685,
396,
892,
316,
176,
180,
834,
766,
19,
806,
68,
795,
453,
231,
616,
53,
246,
912,
709,
814,
23,
336,
570,
811,
481,
113,
168,
762,
362,
697,
730,
543,
25,
289,
197,
11,
987,
628,
233,
734,
655,
242,
276,
77,
915,
698,
444,
594,
67,
89,
393,
154,
673,
417,
299,
726,
199,
510,
39,
206,
530,
942,
446,
482,
525,
768,
635,
267,
370,
979,
678,
837,
966,
559,
842,
568,
293,
877,
109,
148,
99,
381,
345,
134,
544,
244,
135,
590,
158,
459,
368,
838,
883,
366,
405,
945,
425,
926,
910,
865,
940,
50,
343,
186,
815,
784,
529,
184,
642,
478,
338,
254,
817,
285,
677,
777,
829,
550,
874,
824,
981,
179,
749,
226,
927,
0,
633,
357,
257,
223,
493,
537,
410,
695,
69,
971,
905,
6,
660,
666,
55,
700,
380,
350,
631,
62,
215,
534,
182,
487,
653,
742,
360,
875,
139,
708,
451,
75,
145,
419,
699,
706,
124,
15,
882,
474,
630,
412,
97,
972,
394,
497,
127,
142,
151,
868,
209,
613,
307,
454,
61,
108,
527,
318,
495,
825,
166,
377,
422,
906,
195,
963,
947,
654,
869,
781,
539,
652,
51,
321,
130,
953,
268,
733,
278,
612,
476,
196,
178,
326,
201,
640,
401,
959,
291,
757,
648,
452,
79,
455,
193,
884,
658,
535,
63,
43,
304,
853,
840,
359,
84,
928,
282,
810,
916,
152,
159,
964,
385,
81,
188,
676,
643,
745,
770,
786,
437,
140,
334,
372,
312,
286,
776,
211,
737,
727,
595,
407,
253,
775,
122,
115,
374,
571,
898,
450,
332,
844,
588,
270,
90,
3,
440,
119,
45,
715,
944,
671,
674,
649,
141,
819,
793,
429,
900,
683,
70,
935,
813,
818,
791,
956,
292,
213,
330,
887,
602,
899,
835,
87,
651,
428,
202,
841,
542,
467,
435,
938,
277,
624,
281,
661,
896,
731,
843,
32,
398,
129,
126,
341,
441,
820,
764,
80,
787,
952,
153,
506,
489,
941,
382,
430,
606,
464,
344,
851,
763,
462,
502,
161,
541,
16,
549,
997,
774,
503,
925,
457,
943,
625,
308,
847,
427,
263,
363,
54,
156,
937,
688,
420,
523,
754,
397,
546,
885,
722,
486,
260,
619,
538,
513,
532,
157,
558,
505,
367,
946,
573,
934,
76,
714,
647,
812,
723,
249,
20,
744,
250,
94,
103,
342,
251,
911,
632,
191,
670,
42,
978,
682,
859,
986,
92,
458,
91,
932,
735,
369,
252,
203,
314,
610,
957,
187,
265,
364,
871,
636,
220,
256,
114,
466,
324,
65,
993,
408,
320,
400,
988,
327,
908,
10,
27,
597,
962,
212,
929,
322,
488,
756,
617,
771,
555,
752,
445,
105,
710,
247,
974,
732,
118,
165,
922,
245,
759,
996,
608,
822,
801,
792,
858,
611,
46,
881,
283,
468,
975,
659,
577,
302,
827,
273,
83,
579,
229,
804,
584,
713,
739,
909,
117,
349,
718,
150,
389,
74,
416,
40,
724,
684,
800,
593,
668,
500,
618,
656,
298,
765,
348,
346,
376,
778,
740,
809,
921,
274,
958,
897,
170,
173,
136,
86,
41,
66,
240,
545,
967,
547,
783,
560,
985,
98,
969,
218,
439,
347,
138,
576,
335,
939,
160,
748,
288,
411,
626,
333,
889,
907,
823,
924,
977,
681,
106,
504,
198,
965,
976,
587,
831,
101,
355,
205,
387,
703,
521,
637,
175,
471,
826,
222,
604,
38,
832,
8,
133,
258,
578,
933,
162,
769,
317,
78,
833,
313,
48,
217,
128,
305,
60,
919,
646,
845,
328,
589,
691,
404,
961,
664,
536,
228,
461,
528,
711,
645,
225,
557,
830,
694,
518,
721,
970,
164,
736,
36,
149,
406,
18,
230,
21,
442,
620,
983,
522,
747,
259,
111,
264,
192,
431,
351,
395,
319,
24,
116,
485,
508,
329,
890,
465,
301,
918,
663,
279,
672,
861,
948,
799,
163,
171,
71,
297,
850,
189,
639,
112,
846,
255,
287,
773,
772,
14,
463,
17,
85,
72,
689,
33
]
}
================================================
FILE: datainfo/fire/data_split_dict_2.json
================================================
{
"test": [
465,
677,
921,
956,
851,
527,
512,
510,
285,
501,
255,
543,
670,
472,
756,
732,
409,
772,
165,
930,
879,
370,
362,
607,
808,
966,
781,
537,
752,
424,
815,
456,
915,
186,
947,
690,
526,
368,
938,
522,
139,
177,
332,
180,
24,
236,
241,
181,
573,
168,
538,
905,
913,
433,
389,
929,
326,
954,
476,
372,
28,
891,
979,
922,
274,
514,
455,
958,
557,
380,
521,
880,
740,
822,
402,
653,
441,
162,
697,
595,
36,
621,
217,
620,
257,
315,
874,
685,
828,
753,
173,
855,
369,
86,
93,
57,
869,
970,
883,
978
],
"val": [
4,
37,
155,
253,
44,
603,
394,
1,
708,
535,
188,
927,
160,
130,
261,
382,
21,
746,
41,
25,
69,
117,
84,
943,
688,
909,
176,
936,
371,
58,
32,
777,
734,
952,
61,
215,
250,
272,
939,
534,
916,
849,
698,
232,
605,
279,
50,
668,
588,
60,
108,
762,
195,
887,
8,
895,
349,
840,
803,
77,
638,
700,
375,
524,
500,
212,
748,
319,
416,
602,
630,
667,
519,
530,
575,
516,
903,
723,
818,
310,
316,
491,
791,
963,
631,
170,
990,
716,
834,
941,
227,
674,
498,
467,
741,
570,
743,
581,
359,
912
],
"train": [
290,
214,
112,
544,
695,
541,
792,
671,
658,
339,
298,
644,
820,
496,
92,
289,
206,
211,
997,
218,
856,
847,
988,
62,
844,
22,
559,
507,
266,
643,
793,
870,
866,
542,
432,
300,
471,
647,
99,
627,
64,
554,
379,
245,
152,
629,
276,
127,
470,
251,
810,
482,
308,
579,
560,
307,
928,
505,
683,
153,
711,
750,
357,
539,
66,
710,
701,
341,
513,
457,
146,
492,
111,
485,
348,
529,
574,
959,
469,
322,
759,
356,
552,
830,
43,
299,
120,
91,
398,
649,
948,
648,
968,
365,
645,
463,
443,
440,
682,
373,
865,
737,
733,
873,
662,
924,
216,
897,
651,
158,
454,
805,
839,
42,
660,
282,
2,
417,
745,
993,
894,
387,
229,
736,
821,
189,
198,
556,
106,
105,
899,
70,
517,
950,
466,
515,
867,
508,
576,
151,
807,
614,
9,
582,
680,
779,
228,
864,
249,
580,
831,
405,
10,
563,
964,
49,
488,
878,
687,
448,
871,
301,
140,
626,
434,
747,
378,
311,
284,
312,
377,
361,
911,
499,
624,
525,
80,
101,
654,
283,
715,
390,
87,
549,
295,
306,
325,
547,
461,
600,
26,
487,
817,
980,
785,
744,
795,
100,
854,
439,
532,
616,
969,
427,
798,
13,
901,
264,
410,
286,
63,
29,
945,
33,
892,
797,
355,
726,
618,
135,
742,
350,
133,
208,
273,
806,
294,
51,
243,
15,
138,
309,
178,
890,
7,
991,
35,
123,
661,
659,
328,
545,
219,
12,
376,
166,
985,
882,
269,
967,
406,
354,
6,
344,
56,
584,
555,
129,
331,
612,
156,
436,
438,
758,
729,
848,
340,
925,
55,
587,
231,
951,
197,
210,
601,
842,
205,
802,
450,
813,
782,
278,
976,
835,
242,
611,
479,
116,
351,
569,
59,
415,
974,
511,
692,
889,
794,
735,
90,
73,
565,
190,
172,
23,
122,
126,
923,
920,
767,
646,
423,
934,
919,
128,
169,
787,
597,
48,
914,
665,
693,
594,
858,
900,
145,
975,
199,
714,
385,
363,
277,
717,
223,
333,
931,
183,
933,
179,
53,
850,
65,
425,
568,
780,
622,
244,
225,
613,
446,
709,
452,
202,
421,
809,
800,
343,
955,
318,
836,
396,
770,
360,
827,
712,
804,
755,
684,
304,
623,
778,
175,
761,
193,
431,
281,
137,
419,
490,
71,
608,
94,
191,
347,
97,
566,
673,
845,
81,
167,
837,
110,
327,
397,
786,
453,
944,
149,
460,
572,
85,
679,
713,
504,
447,
182,
163,
109,
704,
194,
775,
637,
699,
994,
789,
142,
267,
641,
287,
234,
119,
705,
918,
45,
321,
258,
408,
76,
271,
305,
200,
672,
324,
632,
47,
430,
604,
201,
329,
334,
578,
407,
330,
884,
801,
606,
550,
132,
39,
751,
596,
696,
635,
247,
591,
888,
681,
776,
346,
445,
707,
226,
288,
957,
823,
615,
838,
187,
0,
239,
719,
617,
739,
79,
593,
486,
640,
875,
824,
546,
27,
819,
694,
442,
592,
403,
540,
506,
766,
207,
724,
480,
384,
388,
590,
518,
906,
314,
40,
14,
610,
983,
583,
136,
235,
171,
567,
881,
473,
248,
302,
386,
67,
656,
3,
358,
74,
209,
478,
940,
551,
493,
689,
853,
141,
908,
796,
577,
204,
586,
706,
220,
395,
221,
910,
17,
678,
946,
562,
896,
977,
320,
11,
391,
412,
192,
774,
83,
338,
876,
483,
666,
816,
414,
280,
494,
54,
937,
763,
184,
392,
30,
451,
477,
148,
367,
413,
240,
898,
161,
989,
263,
935,
254,
528,
336,
764,
720,
553,
159,
114,
293,
536,
825,
449,
337,
962,
399,
609,
131,
38,
88,
68,
426,
150,
558,
942,
364,
72,
893,
589,
599,
992,
437,
561,
783,
790,
811,
771,
634,
303,
503,
754,
489,
718,
495,
275,
383,
144,
468,
366,
297,
509,
857,
481,
721,
664,
691,
400,
342,
728,
125,
633,
260,
124,
404,
353,
749,
655,
401,
814,
585,
118,
196,
502,
224,
78,
663,
669,
768,
291,
5,
730,
628,
868,
213,
474,
497,
652,
420,
833,
462,
16,
203,
313,
702,
103,
422,
675,
788,
381,
296,
861,
625,
121,
531,
19,
374,
996,
96,
185,
841,
926,
428,
52,
727,
998,
902,
523,
435,
444,
852,
981,
953,
657,
548,
317,
971,
475,
986,
533,
877,
222,
107,
738,
932,
20,
102,
769,
246,
89,
862,
113,
164,
418,
393,
961,
154,
799,
949,
907,
832,
860,
262,
826,
859,
639,
233,
843,
82,
256,
75,
965,
731,
238,
252,
829,
725,
520,
237,
650,
464,
98,
174,
917,
134,
34,
259,
987,
686,
143,
571,
886,
18,
846,
429,
982,
268,
265,
885,
147,
335,
904,
960,
973,
598,
872,
812,
458,
972,
104,
323,
703,
676,
95,
230,
484,
157,
722,
636,
411,
773,
270,
46,
757,
619,
784,
564,
459,
984,
31,
863,
345,
292,
115,
765,
760,
642,
995,
352
]
}
================================================
FILE: datainfo/fire/data_split_dict_3.json
================================================
{
"test": [
712,
959,
987,
286,
876,
29,
698,
344,
968,
598,
417,
599,
546,
359,
587,
395,
853,
519,
431,
952,
875,
641,
797,
446,
688,
264,
222,
506,
139,
36,
99,
374,
943,
137,
455,
590,
820,
745,
404,
437,
807,
731,
396,
899,
942,
736,
609,
484,
275,
886,
843,
31,
798,
308,
43,
605,
776,
163,
65,
795,
687,
15,
759,
399,
535,
948,
888,
155,
650,
237,
154,
881,
654,
406,
487,
562,
856,
553,
481,
734,
196,
239,
564,
265,
480,
857,
930,
13,
620,
67,
594,
640,
485,
618,
937,
378,
133,
557,
606,
243
],
"val": [
603,
630,
830,
870,
322,
593,
866,
11,
835,
561,
310,
533,
924,
447,
918,
998,
732,
737,
649,
441,
277,
916,
635,
105,
572,
697,
945,
659,
914,
532,
471,
385,
859,
141,
368,
321,
347,
953,
706,
159,
269,
625,
298,
909,
202,
32,
548,
614,
110,
78,
7,
317,
928,
241,
517,
285,
981,
338,
600,
735,
386,
46,
779,
629,
619,
45,
121,
425,
787,
938,
300,
20,
969,
420,
68,
819,
352,
90,
495,
971,
874,
493,
64,
127,
291,
273,
913,
851,
648,
216,
671,
730,
106,
582,
585,
554,
334,
970,
715,
167
],
"train": [
316,
973,
262,
342,
904,
898,
566,
511,
871,
143,
176,
910,
644,
290,
927,
988,
624,
185,
370,
720,
156,
346,
388,
313,
616,
148,
891,
371,
993,
567,
832,
249,
272,
415,
529,
592,
472,
513,
749,
261,
664,
844,
848,
538,
236,
926,
847,
25,
777,
89,
147,
460,
873,
401,
219,
209,
115,
288,
684,
578,
826,
457,
174,
267,
576,
152,
432,
966,
362,
686,
520,
142,
746,
668,
145,
470,
98,
834,
709,
449,
908,
884,
192,
312,
889,
718,
71,
950,
690,
38,
974,
210,
704,
218,
421,
855,
380,
658,
536,
42,
303,
541,
762,
456,
207,
450,
402,
766,
579,
259,
551,
413,
412,
767,
982,
309,
198,
583,
166,
550,
915,
2,
922,
773,
393,
846,
179,
639,
217,
785,
814,
377,
713,
983,
500,
256,
390,
40,
479,
365,
355,
699,
733,
66,
476,
238,
463,
758,
491,
464,
114,
337,
328,
540,
925,
655,
893,
112,
738,
228,
466,
552,
279,
771,
53,
108,
199,
17,
865,
225,
100,
41,
701,
201,
742,
10,
907,
80,
911,
788,
178,
407,
315,
16,
627,
266,
190,
301,
8,
242,
661,
503,
235,
177,
443,
84,
504,
394,
790,
944,
253,
19,
622,
233,
521,
507,
6,
12,
809,
783,
646,
119,
725,
434,
168,
319,
144,
565,
113,
165,
366,
292,
9,
663,
251,
601,
350,
757,
414,
392,
756,
636,
645,
59,
462,
193,
197,
693,
0,
281,
186,
48,
314,
638,
405,
287,
220,
424,
555,
282,
794,
739,
364,
901,
162,
158,
900,
379,
208,
689,
985,
840,
324,
505,
103,
325,
716,
743,
672,
74,
957,
302,
213,
96,
545,
571,
990,
478,
442,
77,
837,
991,
526,
813,
633,
595,
206,
170,
770,
468,
559,
710,
518,
628,
861,
474,
760,
979,
138,
452,
863,
558,
363,
294,
490,
283,
92,
356,
408,
810,
824,
422,
234,
418,
153,
284,
215,
774,
769,
72,
902,
55,
63,
681,
768,
91,
967,
611,
391,
975,
150,
30,
180,
260,
252,
58,
839,
274,
289,
858,
890,
754,
805,
339,
729,
727,
188,
939,
748,
919,
486,
653,
764,
224,
57,
221,
793,
792,
140,
796,
956,
278,
896,
200,
574,
453,
667,
761,
39,
248,
194,
573,
169,
171,
960,
989,
172,
935,
842,
613,
231,
670,
933,
149,
135,
621,
60,
589,
373,
675,
965,
93,
610,
416,
86,
516,
575,
821,
958,
318,
509,
992,
931,
410,
97,
666,
596,
451,
674,
510,
397,
367,
331,
120,
660,
702,
23,
375,
400,
696,
862,
679,
246,
14,
719,
895,
872,
427,
703,
980,
651,
534,
607,
812,
977,
214,
570,
941,
483,
164,
955,
780,
612,
741,
528,
604,
411,
920,
932,
806,
157,
683,
497,
489,
482,
458,
304,
591,
868,
789,
343,
151,
962,
631,
502,
569,
118,
161,
263,
299,
75,
586,
996,
921,
680,
95,
439,
833,
869,
951,
496,
357,
469,
617,
563,
632,
49,
83,
76,
454,
459,
205,
740,
673,
685,
351,
811,
473,
345,
296,
387,
963,
744,
212,
82,
877,
721,
711,
211,
560,
498,
240,
543,
85,
577,
358,
972,
376,
652,
403,
465,
255,
525,
247,
79,
184,
332,
705,
102,
109,
662,
523,
37,
903,
306,
883,
880,
117,
26,
808,
878,
24,
183,
384,
101,
94,
800,
864,
954,
852,
146,
656,
5,
822,
544,
326,
111,
801,
724,
597,
707,
124,
268,
383,
333,
28,
995,
436,
816,
435,
608,
276,
845,
940,
676,
18,
885,
428,
430,
765,
508,
381,
717,
818,
44,
804,
894,
829,
382,
836,
182,
580,
978,
389,
3,
568,
130,
126,
438,
336,
867,
271,
369,
311,
984,
398,
827,
912,
73,
722,
160,
323,
784,
514,
527,
88,
923,
494,
70,
897,
882,
47,
129,
56,
1,
330,
946,
52,
433,
828,
539,
751,
254,
708,
51,
409,
584,
499,
849,
131,
280,
245,
815,
677,
547,
125,
949,
226,
360,
781,
747,
976,
258,
128,
682,
116,
626,
54,
905,
549,
537,
802,
750,
763,
515,
175,
726,
492,
986,
50,
934,
440,
203,
524,
35,
860,
250,
122,
293,
602,
69,
232,
349,
556,
295,
531,
678,
775,
33,
753,
823,
444,
522,
637,
488,
4,
204,
838,
423,
964,
419,
581,
429,
297,
426,
817,
354,
475,
728,
530,
841,
917,
132,
782,
353,
634,
87,
657,
348,
786,
187,
588,
803,
195,
542,
34,
123,
230,
879,
461,
961,
223,
936,
906,
81,
173,
448,
229,
691,
341,
329,
772,
104,
929,
714,
665,
445,
694,
191,
335,
244,
947,
669,
227,
512,
850,
134,
752,
700,
892,
27,
107,
831,
307,
270,
825,
778,
320,
189,
372,
181,
327,
615,
997,
305,
467,
643,
257,
994,
21,
692,
62,
799,
22,
501,
755,
854,
723,
623,
791,
695,
361,
477,
340,
642,
887,
61,
136,
647
]
}
================================================
FILE: datainfo/lava_lamp/data_split_dict_1.json
================================================
{
"test": [
408,
380,
248,
491,
35,
402,
511,
290,
310,
194,
519,
539,
512,
340,
123,
303,
190,
102,
426,
544,
296,
470,
224,
236,
353,
238,
561,
448,
227,
554,
432,
221,
390,
9,
26,
22,
31,
325,
104,
234,
272,
456,
2,
443,
399,
29,
499,
96,
214,
388,
483,
460,
507,
120,
261,
64,
137
],
"val": [
463,
337,
565,
235,
180,
295,
433,
176,
263,
207,
118,
503,
440,
276,
526,
394,
6,
116,
257,
86,
87,
331,
487,
529,
524,
314,
434,
360,
157,
528,
240,
15,
375,
250,
189,
201,
430,
266,
83,
398,
55,
260,
339,
531,
44,
191,
377,
345,
397,
359,
451,
280,
187,
88,
522,
212,
206
],
"train": [
126,
245,
516,
84,
527,
351,
124,
165,
462,
504,
393,
447,
352,
54,
385,
251,
155,
52,
327,
160,
244,
404,
498,
453,
392,
67,
396,
131,
76,
21,
475,
163,
365,
274,
154,
422,
110,
413,
128,
558,
217,
184,
241,
362,
17,
195,
464,
279,
326,
147,
89,
249,
496,
216,
59,
410,
342,
103,
106,
320,
73,
505,
378,
381,
457,
334,
142,
318,
72,
530,
341,
292,
482,
545,
304,
253,
537,
161,
38,
98,
239,
523,
79,
95,
316,
284,
145,
192,
285,
172,
135,
247,
409,
13,
324,
555,
478,
328,
209,
374,
133,
65,
119,
357,
113,
298,
162,
368,
233,
489,
60,
461,
45,
101,
306,
370,
1,
223,
277,
121,
226,
125,
497,
220,
371,
288,
51,
552,
5,
541,
134,
495,
185,
267,
62,
77,
550,
47,
490,
178,
536,
407,
335,
494,
210,
506,
481,
71,
348,
406,
63,
476,
78,
181,
367,
405,
317,
91,
256,
486,
333,
361,
458,
435,
469,
153,
308,
168,
418,
286,
183,
100,
198,
179,
343,
450,
449,
468,
508,
321,
502,
297,
25,
315,
122,
425,
329,
500,
376,
138,
415,
312,
301,
525,
40,
204,
560,
553,
354,
208,
41,
358,
532,
338,
61,
48,
421,
444,
222,
417,
23,
557,
171,
237,
3,
225,
428,
11,
243,
27,
70,
412,
534,
452,
229,
75,
93,
269,
271,
389,
57,
140,
146,
442,
363,
200,
484,
480,
538,
309,
32,
424,
37,
474,
300,
97,
350,
547,
355,
349,
543,
366,
141,
467,
34,
485,
445,
149,
369,
174,
431,
188,
382,
347,
170,
391,
509,
384,
510,
68,
43,
20,
33,
562,
549,
427,
273,
387,
518,
49,
136,
109,
219,
173,
69,
167,
305,
80,
533,
144,
205,
166,
439,
255,
429,
151,
199,
53,
252,
99,
356,
293,
437,
50,
423,
488,
193,
477,
493,
472,
455,
520,
111,
302,
19,
446,
4,
441,
479,
289,
213,
383,
330,
158,
39,
400,
156,
24,
108,
564,
152,
30,
346,
323,
372,
294,
202,
559,
332,
268,
114,
230,
264,
322,
112,
278,
436,
259,
228,
82,
18,
74,
203,
542,
115,
10,
540,
517,
107,
563,
129,
492,
132,
556,
215,
175,
197,
159,
12,
58,
242,
254,
164,
501,
232,
150,
364,
473,
139,
336,
471,
270,
403,
81,
85,
513,
148,
459,
94,
419,
56,
454,
127,
143,
395,
386,
7,
231,
8,
42,
36,
344,
16,
130,
282,
46,
92,
299,
281,
90,
546,
117,
411,
14,
307,
548,
169,
313,
514,
319,
465,
275,
0,
177,
535,
182,
416,
515,
211,
258,
466,
283,
291,
186,
246,
28,
218,
105,
287,
521,
265,
66,
414,
262,
379,
420,
438,
401,
196,
551,
373,
311
]
}
================================================
FILE: datainfo/lava_lamp/data_split_dict_2.json
================================================
{
"test": [
462,
460,
465,
523,
512,
510,
285,
501,
255,
472,
409,
165,
528,
370,
362,
546,
424,
456,
186,
526,
368,
531,
522,
139,
177,
332,
180,
24,
236,
241,
181,
168,
538,
433,
389,
326,
476,
372,
28,
557,
274,
514,
455,
380,
521,
402,
441,
162,
36,
217,
257,
315,
173,
369,
86,
93,
57
],
"val": [
30,
54,
381,
97,
495,
4,
505,
174,
420,
401,
38,
451,
319,
350,
187,
262,
250,
106,
374,
159,
208,
301,
489,
333,
259,
265,
287,
258,
425,
361,
519,
155,
158,
245,
466,
395,
137,
560,
464,
448,
85,
427,
358,
417,
166,
481,
113,
337,
249,
233,
530,
515,
471,
371,
290,
179,
537
],
"train": [
89,
278,
542,
391,
303,
504,
223,
215,
320,
114,
559,
327,
377,
331,
488,
336,
76,
555,
14,
206,
539,
81,
96,
507,
153,
124,
150,
225,
470,
399,
284,
50,
307,
5,
33,
203,
509,
199,
541,
445,
341,
457,
218,
239,
273,
83,
335,
246,
19,
556,
322,
393,
224,
475,
387,
340,
385,
463,
352,
169,
357,
133,
252,
483,
90,
384,
230,
170,
508,
305,
411,
207,
244,
347,
296,
3,
79,
143,
421,
440,
423,
379,
189,
275,
221,
256,
408,
175,
243,
517,
458,
355,
121,
163,
183,
61,
55,
563,
272,
1,
188,
397,
110,
62,
228,
392,
182,
529,
254,
449,
102,
151,
482,
104,
375,
478,
498,
365,
443,
346,
324,
72,
312,
27,
141,
279,
496,
234,
253,
154,
160,
413,
410,
551,
144,
291,
439,
487,
506,
364,
292,
516,
105,
64,
235,
70,
220,
13,
415,
240,
46,
545,
7,
400,
339,
264,
394,
193,
103,
286,
311,
518,
43,
271,
329,
63,
454,
238,
195,
406,
138,
396,
547,
363,
219,
317,
297,
84,
407,
171,
398,
99,
280,
32,
281,
520,
212,
405,
511,
544,
293,
491,
45,
226,
323,
536,
147,
328,
149,
21,
202,
561,
178,
109,
140,
66,
152,
213,
40,
479,
342,
75,
513,
438,
300,
383,
527,
330,
122,
360,
316,
68,
95,
204,
548,
117,
91,
74,
426,
535,
429,
11,
198,
120,
540,
304,
6,
295,
533,
444,
100,
164,
492,
298,
432,
419,
247,
366,
502,
31,
412,
276,
493,
192,
35,
270,
200,
416,
59,
98,
251,
112,
39,
277,
283,
308,
145,
480,
499,
469,
306,
248,
210,
435,
231,
8,
101,
156,
359,
314,
211,
288,
390,
190,
148,
289,
60,
562,
356,
486,
485,
48,
92,
268,
494,
26,
313,
261,
558,
222,
436,
108,
194,
549,
484,
345,
237,
266,
524,
111,
53,
353,
564,
553,
51,
418,
123,
44,
467,
56,
348,
209,
196,
554,
403,
461,
269,
142,
434,
131,
428,
534,
490,
446,
447,
41,
128,
37,
227,
119,
404,
431,
260,
118,
325,
232,
49,
87,
82,
67,
17,
129,
430,
343,
71,
503,
9,
450,
214,
310,
134,
132,
459,
73,
167,
263,
500,
201,
299,
477,
414,
543,
550,
52,
161,
351,
338,
47,
115,
242,
78,
497,
318,
205,
135,
23,
378,
309,
282,
229,
157,
15,
468,
172,
146,
565,
382,
552,
321,
474,
176,
2,
18,
77,
126,
22,
473,
197,
0,
354,
442,
94,
376,
80,
65,
130,
191,
10,
373,
20,
525,
34,
58,
42,
12,
344,
127,
88,
184,
185,
29,
16,
388,
367,
216,
452,
107,
422,
125,
136,
437,
69,
267,
386,
453,
349,
116,
302,
532,
25,
334,
294
]
}
================================================
FILE: datainfo/lava_lamp/data_split_dict_3.json
================================================
{
"test": [
425,
324,
216,
106,
334,
167,
286,
29,
344,
549,
417,
359,
395,
519,
431,
541,
446,
264,
222,
506,
139,
36,
99,
374,
137,
455,
404,
437,
396,
484,
275,
31,
308,
43,
163,
65,
15,
399,
535,
155,
237,
154,
406,
487,
553,
481,
196,
239,
265,
480,
13,
67,
485,
378,
133,
557,
243
],
"val": [
173,
444,
21,
353,
79,
134,
312,
149,
208,
101,
16,
274,
307,
55,
39,
3,
158,
18,
120,
258,
142,
472,
451,
282,
169,
300,
367,
193,
23,
389,
314,
309,
22,
60,
525,
212,
393,
218,
150,
10,
77,
459,
210,
34,
409,
176,
45,
247,
327,
536,
246,
32,
63,
145,
136,
293,
505
],
"train": [
188,
28,
326,
171,
483,
422,
494,
319,
440,
305,
59,
355,
379,
299,
73,
383,
178,
54,
376,
331,
352,
454,
234,
421,
112,
42,
442,
529,
354,
445,
342,
364,
46,
187,
565,
126,
103,
436,
181,
104,
199,
35,
41,
450,
287,
191,
38,
449,
427,
539,
156,
526,
358,
270,
382,
517,
313,
388,
325,
85,
24,
555,
458,
242,
402,
466,
426,
272,
423,
37,
411,
443,
227,
292,
100,
175,
273,
298,
322,
470,
75,
56,
335,
91,
504,
551,
424,
384,
51,
119,
19,
231,
141,
500,
372,
74,
229,
88,
512,
205,
168,
520,
283,
279,
511,
202,
375,
71,
44,
269,
195,
47,
419,
50,
297,
351,
78,
276,
263,
248,
124,
410,
252,
105,
302,
157,
394,
476,
545,
435,
380,
509,
130,
9,
232,
76,
72,
236,
516,
362,
385,
387,
288,
206,
337,
471,
254,
336,
253,
194,
221,
561,
530,
260,
1,
330,
26,
528,
144,
81,
277,
96,
562,
524,
118,
320,
217,
143,
190,
83,
7,
33,
318,
57,
251,
225,
151,
69,
255,
465,
267,
4,
200,
460,
390,
107,
226,
179,
369,
381,
201,
398,
110,
159,
240,
162,
207,
284,
166,
428,
117,
363,
281,
498,
508,
20,
203,
338,
185,
183,
245,
478,
241,
556,
109,
84,
48,
371,
92,
492,
490,
531,
165,
98,
89,
414,
228,
349,
80,
295,
391,
131,
538,
182,
489,
8,
123,
507,
12,
339,
397,
503,
499,
14,
219,
0,
316,
198,
82,
429,
532,
108,
563,
457,
370,
430,
127,
412,
289,
204,
249,
467,
544,
140,
475,
285,
356,
62,
534,
306,
420,
521,
310,
129,
64,
477,
58,
27,
482,
463,
268,
488,
365,
546,
257,
87,
418,
502,
6,
25,
433,
461,
262,
340,
558,
125,
341,
146,
495,
116,
333,
278,
147,
496,
554,
462,
523,
439,
401,
261,
244,
2,
102,
456,
211,
469,
209,
290,
214,
148,
213,
177,
518,
343,
564,
493,
215,
66,
537,
497,
501,
542,
328,
174,
400,
93,
294,
522,
97,
271,
17,
61,
115,
434,
230,
373,
111,
360,
172,
40,
86,
224,
114,
345,
407,
164,
386,
438,
49,
357,
332,
527,
547,
95,
514,
122,
533,
513,
113,
256,
468,
560,
350,
291,
559,
53,
447,
153,
135,
448,
392,
474,
94,
186,
90,
543,
464,
303,
152,
233,
408,
128,
189,
416,
346,
540,
413,
11,
250,
377,
473,
361,
311,
405,
347,
180,
238,
170,
321,
432,
30,
68,
323,
301,
315,
486,
491,
161,
296,
552,
403,
5,
452,
280,
548,
453,
132,
223,
550,
121,
366,
368,
510,
220,
138,
259,
415,
317,
52,
515,
348,
304,
329,
197,
266,
235,
192,
479,
441,
70,
184,
160
]
}
================================================
FILE: datainfo/reaction_diffusion/data_split_dict_1.json
================================================
{
"test": [
83,
60,
57,
63,
15,
32,
8,
97,
72,
17
],
"val": [
92,
0,
77,
55,
49,
3,
62,
12,
26,
48
],
"train": [
53,
37,
65,
51,
4,
20,
38,
9,
10,
81,
44,
36,
84,
50,
96,
90,
66,
16,
80,
33,
24,
52,
91,
99,
64,
5,
58,
76,
39,
79,
23,
94,
30,
73,
25,
47,
31,
45,
19,
87,
42,
68,
95,
21,
7,
67,
46,
82,
11,
6,
41,
86,
88,
70,
18,
78,
71,
59,
43,
61,
22,
14,
35,
93,
56,
28,
98,
54,
27,
89,
1,
69,
74,
2,
85,
40,
13,
75,
29,
34
]
}
================================================
FILE: datainfo/reaction_diffusion/data_split_dict_2.json
================================================
{
"test": [
77,
32,
39,
85,
94,
21,
46,
10,
11,
7
],
"val": [
47,
65,
50,
81,
55,
20,
74,
4,
90,
27
],
"train": [
0,
76,
61,
63,
1,
71,
2,
6,
16,
19,
13,
24,
49,
12,
75,
9,
83,
72,
5,
41,
99,
45,
89,
53,
79,
18,
52,
92,
14,
42,
68,
44,
38,
84,
36,
17,
31,
15,
70,
88,
25,
97,
51,
73,
66,
37,
78,
33,
80,
26,
82,
28,
60,
35,
43,
57,
23,
58,
91,
8,
62,
93,
98,
86,
29,
30,
22,
95,
67,
54,
48,
40,
59,
96,
3,
87,
34,
64,
56,
69
]
}
================================================
FILE: datainfo/reaction_diffusion/data_split_dict_3.json
================================================
{
"test": [
8,
74,
80,
60,
77,
47,
16,
69,
75,
30
],
"val": [
85,
97,
87,
24,
29,
70,
33,
93,
1,
94
],
"train": [
35,
41,
45,
4,
76,
12,
0,
64,
42,
11,
81,
39,
7,
48,
78,
15,
53,
62,
9,
52,
68,
55,
63,
65,
61,
67,
95,
18,
58,
86,
99,
92,
10,
17,
72,
21,
14,
44,
37,
91,
22,
59,
83,
32,
26,
98,
40,
27,
43,
96,
13,
31,
57,
2,
6,
23,
56,
71,
28,
36,
51,
46,
25,
54,
73,
79,
34,
3,
38,
5,
20,
90,
88,
49,
66,
89,
84,
19,
50,
82
]
}
================================================
FILE: datainfo/single_pendulum/data_split_dict_1.json
================================================
{
"test": [
187,
370,
362,
470,
57,
938,
678,
3,
708,
1137,
730,
993,
846,
1033,
409,
746,
985,
114,
872,
420,
1062,
264,
1049,
785,
11,
551,
940,
723,
704,
1052,
828,
475,
1105,
408,
25,
464,
1028,
345,
348,
806,
631,
89,
961,
60,
1002,
758,
1142,
1066,
335,
221,
1041,
898,
177,
767,
1123,
751,
354,
848,
827,
497,
983,
70,
805,
1034,
1022,
581,
621,
388,
1039,
1171,
1025,
681,
247,
607,
380,
204,
1139,
852,
44,
593,
941,
448,
472,
707,
477,
1132,
1015,
896,
454,
1080,
59,
864,
443,
780,
18,
1108,
52,
45,
62,
650,
209,
468,
545,
912,
4,
886,
798,
58,
999,
192,
429,
777,
967,
920,
1014,
241,
522,
129,
1165,
275
],
"val": [
411,
892,
626,
333,
17,
1071,
516,
303,
399,
1151,
960,
106,
504,
198,
605,
1172,
918,
690,
587,
210,
101,
355,
205,
387,
1000,
521,
637,
720,
1186,
890,
888,
847,
175,
471,
583,
922,
1096,
1046,
839,
604,
38,
870,
899,
574,
8,
133,
258,
578,
426,
162,
761,
305,
1122,
939,
317,
78,
879,
1036,
313,
1053,
1167,
217,
991,
257,
611,
120,
1093,
657,
808,
1178,
457,
923,
451,
873,
1183,
328,
72,
299,
813,
36,
461,
42,
884,
428,
1044,
519,
222,
529,
385,
862,
703,
791,
638,
48,
233,
970,
1016,
659,
931,
603,
558,
344,
1079,
326,
342,
142,
594,
705,
378,
224,
550,
511,
575,
29,
927,
34,
170,
144,
66,
1196
],
"train": [
483,
489,
341,
685,
96,
1,
1148,
542,
656,
744,
1085,
613,
855,
775,
680,
111,
1081,
648,
308,
733,
1023,
189,
759,
61,
482,
639,
1190,
228,
776,
596,
885,
1024,
1187,
1048,
295,
1158,
1146,
1094,
819,
684,
540,
793,
856,
1099,
622,
262,
773,
609,
107,
263,
260,
1160,
958,
474,
672,
739,
442,
478,
887,
1083,
734,
612,
1084,
956,
843,
55,
1067,
750,
823,
9,
532,
112,
282,
329,
520,
874,
359,
1125,
1102,
514,
668,
1164,
459,
197,
844,
753,
635,
633,
185,
1195,
966,
201,
906,
1059,
569,
146,
580,
39,
1047,
446,
597,
276,
43,
1042,
533,
561,
166,
1003,
154,
778,
647,
243,
67,
768,
617,
1149,
634,
981,
649,
450,
507,
215,
1057,
740,
691,
710,
1120,
858,
889,
682,
71,
531,
869,
415,
849,
444,
237,
716,
214,
764,
735,
765,
595,
486,
90,
834,
989,
437,
292,
1127,
549,
330,
16,
304,
559,
1168,
104,
186,
916,
719,
99,
623,
476,
293,
498,
286,
1136,
905,
503,
143,
952,
917,
167,
242,
26,
797,
149,
315,
652,
897,
871,
234,
271,
223,
1029,
769,
238,
673,
924,
294,
1199,
567,
33,
207,
821,
945,
795,
895,
548,
784,
729,
987,
412,
908,
986,
965,
375,
427,
6,
891,
456,
195,
980,
641,
435,
419,
1004,
925,
721,
538,
80,
718,
762,
679,
1063,
131,
937,
311,
307,
552,
485,
1054,
1043,
959,
1176,
840,
1037,
432,
699,
1177,
812,
424,
1157,
752,
360,
190,
68,
316,
693,
1194,
383,
384,
951,
517,
741,
1110,
1072,
671,
676,
219,
179,
1121,
30,
396,
1106,
1005,
585,
402,
321,
954,
943,
606,
749,
255,
859,
530,
865,
390,
395,
976,
877,
556,
268,
1007,
755,
957,
909,
702,
933,
509,
929,
130,
582,
289,
1073,
277,
756,
973,
1100,
394,
525,
81,
598,
100,
147,
379,
196,
644,
851,
2,
782,
351,
824,
1090,
56,
995,
586,
636,
919,
660,
108,
199,
1045,
280,
1091,
1107,
553,
757,
421,
193,
655,
653,
119,
494,
492,
1019,
675,
343,
1134,
140,
339,
692,
453,
1018,
438,
1163,
964,
571,
695,
640,
28,
696,
95,
51,
510,
164,
493,
910,
54,
1188,
156,
911,
998,
1170,
123,
838,
158,
184,
760,
401,
371,
1141,
974,
269,
235,
134,
770,
75,
372,
747,
227,
296,
1032,
374,
1173,
487,
97,
434,
127,
893,
1129,
841,
206,
591,
15,
338,
125,
361,
619,
694,
113,
366,
599,
413,
469,
404,
842,
88,
1155,
787,
200,
1092,
1068,
738,
714,
463,
630,
543,
913,
334,
152,
481,
1114,
645,
285,
462,
854,
188,
337,
267,
425,
357,
984,
717,
137,
1031,
121,
947,
270,
1131,
473,
484,
1086,
1013,
836,
570,
47,
353,
82,
1088,
832,
84,
501,
226,
336,
1075,
666,
35,
216,
781,
508,
50,
465,
491,
23,
763,
669,
926,
674,
132,
687,
236,
77,
155,
278,
350,
1058,
715,
1082,
534,
135,
467,
178,
5,
309,
397,
417,
37,
754,
670,
433,
902,
815,
524,
479,
544,
664,
79,
701,
789,
528,
748,
829,
148,
332,
495,
904,
539,
381,
745,
790,
69,
1061,
414,
624,
398,
727,
368,
202,
537,
174,
535,
590,
1162,
1051,
643,
978,
1087,
642,
602,
915,
115,
972,
779,
1038,
731,
358,
816,
452,
850,
0,
331,
1145,
833,
515,
447,
568,
382,
663,
1135,
139,
1192,
13,
589,
689,
867,
290,
1180,
392,
213,
124,
430,
502,
365,
722,
172,
935,
651,
318,
279,
955,
950,
151,
291,
736,
724,
266,
248,
31,
21,
883,
665,
194,
1184,
572,
837,
254,
284,
820,
1150,
418,
546,
441,
122,
1021,
499,
725,
662,
963,
809,
406,
712,
391,
246,
455,
1124,
1065,
85,
875,
928,
914,
783,
536,
1074,
766,
168,
527,
393,
356,
403,
1027,
53,
994,
997,
436,
700,
159,
180,
386,
860,
831,
620,
126,
87,
608,
1111,
169,
1154,
565,
713,
1011,
319,
208,
646,
163,
377,
523,
73,
709,
1069,
281,
625,
573,
1117,
422,
230,
490,
506,
814,
244,
231,
654,
1161,
512,
772,
541,
181,
7,
944,
239,
932,
627,
176,
410,
1098,
407,
868,
802,
737,
907,
1153,
141,
901,
363,
63,
19,
880,
1008,
661,
24,
1156,
1101,
992,
1143,
405,
1104,
1115,
711,
449,
794,
1119,
562,
440,
1030,
698,
811,
1006,
253,
683,
49,
161,
1070,
975,
306,
182,
979,
706,
566,
688,
1109,
323,
32,
1060,
145,
211,
1197,
300,
616,
526,
726,
109,
312,
817,
1077,
153,
183,
1198,
969,
996,
881,
774,
513,
1133,
157,
799,
505,
367,
297,
822,
1179,
22,
76,
1095,
1017,
900,
1193,
894,
1055,
249,
20,
225,
250,
94,
818,
1103,
962,
557,
103,
1064,
251,
310,
592,
810,
191,
953,
1130,
792,
968,
232,
948,
658,
588,
826,
92,
458,
771,
91,
287,
876,
369,
252,
203,
314,
1138,
554,
1169,
265,
364,
677,
480,
1175,
835,
796,
632,
803,
220,
256,
1097,
466,
615,
324,
65,
64,
1113,
320,
400,
460,
327,
610,
743,
863,
866,
10,
27,
853,
325,
667,
212,
102,
322,
488,
728,
259,
1056,
301,
555,
825,
882,
445,
105,
1010,
1009,
1152,
697,
171,
1040,
118,
165,
431,
600,
804,
245,
1189,
1020,
1116,
845,
1001,
423,
93,
14,
686,
628,
12,
1026,
1144,
46,
1126,
110,
283,
1181,
1012,
934,
577,
302,
373,
273,
83,
579,
229,
563,
584,
1166,
1140,
800,
601,
629,
117,
349,
128,
1089,
150,
807,
1185,
389,
74,
416,
40,
1035,
971,
788,
564,
1159,
949,
500,
732,
988,
618,
990,
930,
298,
116,
1118,
346,
376,
261,
861,
518,
614,
340,
1191,
274,
946,
1112,
1076,
173,
136,
86,
41,
742,
1078,
240,
1182,
786,
496,
547,
1050,
942,
903,
936,
352,
560,
1147,
857,
98,
977,
272,
218,
439,
347,
138,
801,
576,
830,
1128,
878,
982,
160,
1174,
288,
921
]
}
================================================
FILE: datainfo/single_pendulum/data_split_dict_2.json
================================================
{
"test": [
789,
3,
1071,
376,
321,
261,
523,
764,
43,
83,
51,
138,
235,
169,
1167,
510,
352,
737,
742,
116,
65,
866,
123,
431,
501,
544,
1163,
1069,
1113,
464,
559,
100,
120,
217,
391,
17,
699,
154,
750,
1048,
1001,
425,
638,
832,
1039,
1060,
1032,
621,
633,
982,
1181,
340,
664,
454,
996,
935,
718,
1171,
931,
1152,
1055,
1025,
1020,
571,
1003,
511,
1086,
944,
818,
330,
1156,
741,
724,
1178,
1075,
849,
912,
372,
1146,
1052,
736,
1162,
1044,
279,
355,
665,
361,
48,
472,
483,
363,
1147,
336,
1076,
867,
778,
652,
952,
745,
56,
1191,
549,
1028,
911,
1114,
761,
1042,
805,
882,
324,
1190,
73,
434,
515,
631,
346,
739,
173,
187,
115
],
"val": [
954,
706,
296,
375,
625,
121,
943,
531,
19,
374,
491,
821,
679,
96,
942,
185,
983,
537,
951,
428,
902,
52,
605,
595,
21,
1158,
435,
444,
825,
746,
215,
986,
1175,
1037,
701,
1108,
389,
657,
548,
317,
1109,
475,
685,
533,
25,
222,
107,
237,
768,
186,
20,
102,
104,
246,
89,
1151,
524,
113,
1029,
418,
393,
1046,
1117,
9,
570,
1101,
525,
1160,
467,
164,
513,
150,
910,
476,
505,
64,
474,
928,
196,
349,
1149,
269,
68,
518,
1099,
287,
36,
859,
536,
530,
698,
294,
671,
997,
804,
1085,
917,
49,
208,
647,
191,
461,
968,
314,
822,
540,
93,
918,
1194,
63,
1000,
690,
585,
231,
704,
8,
74,
310,
507,
88
],
"train": [
689,
907,
403,
285,
1073,
796,
316,
140,
1183,
132,
861,
1047,
167,
556,
985,
999,
1185,
891,
984,
199,
717,
40,
937,
354,
405,
271,
1155,
1125,
262,
900,
292,
304,
151,
909,
397,
580,
820,
646,
667,
1057,
183,
684,
760,
450,
1159,
726,
198,
747,
880,
1012,
335,
719,
39,
913,
84,
168,
207,
1051,
396,
618,
1088,
834,
567,
146,
923,
814,
921,
629,
299,
1139,
45,
734,
216,
758,
302,
424,
128,
60,
406,
85,
447,
205,
710,
547,
566,
32,
156,
18,
892,
105,
1170,
239,
659,
1061,
797,
76,
248,
541,
641,
783,
969,
1199,
508,
94,
58,
394,
639,
863,
479,
819,
1034,
487,
947,
862,
283,
305,
327,
438,
744,
975,
386,
723,
202,
756,
979,
236,
1092,
1094,
774,
343,
219,
920,
360,
945,
601,
810,
998,
568,
932,
197,
603,
830,
527,
808,
940,
46,
792,
875,
816,
1005,
506,
916,
634,
976,
244,
1009,
2,
1198,
290,
766,
1182,
840,
417,
182,
249,
210,
443,
623,
307,
590,
1043,
592,
925,
155,
858,
675,
90,
427,
442,
848,
846,
1093,
692,
1030,
1131,
35,
255,
111,
1166,
469,
1016,
22,
614,
377,
1018,
26,
195,
499,
165,
37,
333,
883,
899,
247,
906,
981,
288,
365,
616,
693,
1070,
7,
1135,
441,
1179,
874,
188,
992,
851,
347,
1054,
456,
258,
991,
529,
1110,
1066,
970,
69,
370,
457,
735,
1137,
546,
1050,
13,
879,
312,
348,
872,
23,
226,
611,
1130,
678,
445,
57,
557,
569,
853,
668,
532,
272,
806,
379,
977,
974,
223,
266,
385,
127,
971,
31,
562,
1112,
1058,
624,
753,
798,
1161,
300,
281,
980,
552,
253,
42,
142,
228,
486,
573,
436,
1103,
407,
459,
781,
1023,
1142,
780,
1077,
225,
175,
517,
714,
163,
29,
1157,
622,
14,
6,
961,
473,
274,
660,
157,
329,
295,
1017,
241,
617,
1065,
1021,
158,
731,
828,
811,
1064,
227,
378,
888,
565,
1063,
586,
1116,
411,
1098,
707,
232,
229,
87,
1100,
423,
833,
1041,
838,
1134,
612,
133,
896,
214,
109,
578,
946,
887,
645,
398,
1121,
539,
757,
670,
212,
211,
687,
572,
865,
1196,
700,
80,
584,
149,
588,
662,
351,
1136,
972,
542,
708,
325,
521,
455,
1111,
1096,
957,
432,
306,
786,
655,
1033,
458,
591,
850,
1193,
134,
380,
1174,
1148,
1144,
1040,
1010,
962,
200,
99,
683,
876,
384,
282,
308,
482,
1192,
1123,
857,
357,
894,
583,
669,
117,
615,
250,
785,
466,
564,
139,
575,
1090,
607,
50,
415,
1053,
1097,
868,
950,
66,
498,
1180,
1118,
174,
152,
953,
193,
967,
898,
1084,
77,
1019,
847,
53,
705,
776,
345,
649,
1056,
268,
581,
201,
490,
331,
318,
635,
619,
672,
519,
598,
27,
632,
108,
492,
534,
126,
59,
276,
812,
0,
251,
855,
484,
122,
448,
270,
1124,
596,
267,
416,
941,
835,
949,
964,
749,
930,
800,
277,
864,
356,
70,
15,
286,
71,
1013,
33,
1102,
1127,
460,
439,
897,
613,
332,
419,
172,
784,
504,
703,
933,
12,
844,
112,
606,
772,
948,
129,
4,
593,
408,
234,
993,
666,
637,
179,
512,
965,
839,
966,
110,
676,
463,
751,
421,
815,
465,
410,
1078,
794,
92,
1189,
1015,
485,
1067,
955,
233,
597,
1195,
119,
1059,
576,
135,
278,
1133,
440,
1035,
759,
147,
1107,
929,
344,
47,
245,
252,
914,
608,
654,
256,
369,
430,
885,
257,
339,
769,
145,
97,
680,
206,
323,
807,
1022,
752,
1138,
554,
787,
242,
604,
555,
738,
446,
721,
809,
171,
359,
106,
711,
130,
642,
322,
190,
1014,
488,
802,
362,
673,
178,
677,
904,
653,
1083,
788,
1105,
630,
799,
328,
775,
713,
1045,
350,
409,
387,
563,
1188,
651,
702,
915,
1122,
189,
919,
194,
480,
243,
101,
162,
334,
1106,
755,
905,
137,
813,
298,
452,
889,
170,
371,
767,
716,
79,
852,
326,
218,
373,
388,
644,
827,
62,
284,
535,
836,
574,
886,
238,
41,
681,
91,
643,
516,
153,
543,
1141,
901,
990,
648,
520,
95,
1049,
893,
658,
319,
661,
341,
311,
765,
264,
926,
309,
551,
881,
136,
1132,
1095,
958,
762,
453,
577,
1008,
550,
34,
963,
1,
729,
1024,
289,
10,
826,
301,
1074,
1087,
390,
1145,
166,
55,
1091,
640,
691,
1140,
790,
779,
1164,
934,
81,
28,
471,
514,
273,
470,
903,
1082,
496,
1129,
67,
777,
1002,
824,
939,
1081,
358,
1173,
209,
478,
177,
579,
1026,
1080,
493,
987,
694,
627,
1154,
141,
626,
1104,
922,
594,
600,
791,
204,
1143,
873,
1011,
220,
395,
620,
582,
1187,
221,
770,
1115,
973,
181,
1165,
803,
728,
842,
1120,
320,
11,
650,
1184,
682,
412,
192,
636,
545,
230,
743,
1089,
956,
763,
730,
338,
560,
368,
1169,
841,
890,
960,
748,
727,
259,
414,
280,
494,
54,
522,
402,
184,
754,
392,
609,
30,
878,
451,
477,
148,
989,
587,
782,
801,
367,
413,
240,
720,
1168,
161,
1197,
263,
526,
732,
254,
528,
1172,
854,
433,
924,
553,
159,
114,
293,
1119,
176,
449,
337,
843,
686,
725,
869,
399,
871,
131,
38,
663,
829,
697,
1079,
884,
1186,
426,
988,
1031,
558,
180,
364,
72,
98,
589,
837,
877,
599,
86,
715,
695,
712,
437,
561,
265,
610,
500,
160,
1007,
1150,
303,
696,
856,
503,
429,
823,
1153,
1027,
489,
538,
495,
275,
383,
817,
144,
468,
366,
297,
509,
722,
927,
44,
795,
481,
674,
1128,
1004,
1177,
709,
995,
400,
656,
342,
870,
1068,
82,
125,
936,
260,
124,
908,
404,
353,
771,
143,
938,
733,
401,
382,
1072,
118,
1038,
502,
773,
224,
78,
740,
978,
959,
24,
291,
5,
75,
628,
602,
1126,
213,
1036,
497,
1176,
420,
61,
462,
831,
16,
845,
688,
793,
860,
203,
313,
1006,
103,
422,
994,
895,
1062,
315,
381
]
}
================================================
FILE: datainfo/single_pendulum/data_split_dict_3.json
================================================
{
"test": [
1194,
644,
1184,
23,
694,
620,
1067,
1157,
895,
1155,
486,
883,
555,
1153,
210,
1152,
1065,
942,
771,
1121,
283,
737,
642,
695,
86,
319,
539,
597,
835,
404,
64,
1096,
221,
157,
14,
634,
1161,
483,
1035,
571,
677,
773,
92,
90,
243,
850,
1167,
601,
41,
1181,
840,
136,
704,
181,
990,
987,
129,
254,
583,
546,
432,
213,
1109,
668,
334,
572,
58,
689,
475,
834,
1093,
718,
790,
1038,
862,
1172,
893,
528,
444,
1013,
278,
73,
199,
748,
274,
910,
808,
874,
793,
968,
551,
63,
616,
87,
326,
131,
31,
798,
1071,
310,
474,
308,
813,
975,
1125,
1107,
963,
392,
479,
1128,
531,
960,
26,
134,
1189,
970,
757,
267,
1114,
487
],
"val": [
710,
822,
925,
35,
46,
250,
829,
122,
293,
602,
878,
1029,
882,
232,
911,
349,
556,
295,
1190,
765,
678,
1079,
856,
467,
763,
33,
227,
734,
949,
972,
1145,
1158,
522,
637,
901,
852,
965,
1047,
4,
204,
159,
423,
1097,
36,
419,
581,
429,
297,
426,
649,
354,
1148,
277,
870,
812,
530,
298,
431,
132,
603,
353,
1051,
825,
175,
696,
570,
375,
645,
390,
69,
247,
460,
554,
923,
446,
1147,
163,
346,
897,
459,
683,
659,
208,
198,
891,
383,
671,
488,
1170,
455,
1024,
1115,
269,
55,
214,
772,
615,
540,
1086,
640,
379,
745,
363,
655,
611,
934,
514,
756,
43,
124,
45,
1002,
1119,
1074,
722,
954,
680,
123,
272,
1098
],
"train": [
464,
1062,
890,
22,
660,
427,
290,
167,
469,
1159,
341,
712,
560,
457,
77,
952,
639,
462,
1185,
1195,
654,
135,
83,
797,
730,
372,
650,
218,
225,
787,
285,
265,
343,
279,
452,
1028,
397,
951,
948,
421,
855,
625,
158,
846,
1004,
1009,
753,
72,
466,
559,
914,
80,
591,
738,
142,
1120,
847,
408,
708,
973,
552,
236,
364,
1186,
207,
606,
1103,
338,
741,
622,
282,
178,
991,
263,
1054,
1112,
612,
586,
1072,
1193,
371,
1129,
480,
66,
1176,
781,
691,
507,
735,
853,
91,
575,
386,
192,
928,
67,
335,
785,
656,
585,
220,
789,
374,
347,
706,
231,
196,
188,
752,
681,
348,
324,
1177,
309,
151,
727,
929,
525,
688,
100,
95,
1141,
1124,
352,
519,
1090,
624,
1075,
623,
266,
351,
219,
723,
1133,
470,
1085,
75,
38,
1150,
510,
1031,
998,
800,
8,
641,
676,
370,
104,
288,
766,
415,
286,
292,
521,
411,
42,
1182,
443,
1110,
1092,
811,
873,
1073,
731,
842,
977,
350,
561,
788,
424,
1179,
312,
876,
906,
141,
321,
228,
564,
294,
791,
414,
922,
76,
118,
328,
138,
1166,
1156,
356,
628,
714,
841,
777,
1142,
799,
803,
103,
875,
844,
638,
1063,
1055,
726,
953,
1113,
313,
518,
315,
449,
935,
1160,
65,
881,
482,
6,
633,
784,
296,
498,
10,
1081,
337,
331,
865,
867,
434,
1169,
380,
605,
394,
154,
889,
884,
393,
412,
143,
217,
500,
587,
783,
815,
1052,
945,
535,
262,
961,
1007,
635,
12,
284,
61,
1089,
767,
672,
1076,
1197,
607,
717,
1117,
327,
999,
866,
1149,
0,
399,
1010,
345,
230,
994,
667,
121,
1036,
49,
1144,
1199,
1101,
992,
489,
451,
249,
533,
979,
927,
137,
666,
595,
496,
770,
959,
212,
534,
1099,
978,
1100,
197,
89,
697,
686,
171,
913,
1017,
477,
1146,
211,
558,
827,
1108,
201,
156,
619,
1039,
179,
915,
1088,
679,
437,
169,
657,
96,
27,
246,
502,
627,
684,
473,
647,
1005,
53,
739,
1046,
303,
485,
851,
29,
955,
373,
576,
699,
172,
15,
1131,
287,
930,
725,
589,
238,
653,
148,
454,
304,
1171,
511,
59,
318,
170,
84,
97,
857,
1162,
48,
947,
112,
481,
252,
703,
941,
1026,
617,
981,
837,
1021,
16,
1042,
162,
664,
838,
760,
517,
1059,
251,
1033,
420,
1154,
886,
71,
1198,
325,
1104,
257,
1015,
299,
830,
1084,
614,
848,
506,
721,
966,
849,
715,
596,
1018,
107,
888,
759,
418,
305,
669,
504,
1014,
608,
1027,
536,
368,
19,
233,
1122,
224,
796,
1020,
447,
234,
698,
11,
545,
956,
750,
106,
216,
907,
34,
1135,
242,
899,
200,
82,
818,
450,
1140,
476,
209,
21,
402,
1053,
161,
248,
113,
387,
357,
819,
543,
833,
32,
1095,
1048,
916,
1087,
40,
366,
17,
396,
578,
1006,
439,
1080,
68,
983,
355,
682,
445,
168,
165,
648,
9,
898,
743,
39,
259,
119,
316,
1061,
950,
25,
472,
574,
782,
405,
513,
971,
1056,
456,
858,
674,
145,
289,
57,
986,
786,
1044,
1180,
709,
416,
821,
108,
1019,
940,
693,
828,
744,
646,
805,
177,
503,
239,
905,
2,
859,
526,
824,
764,
920,
317,
871,
562,
417,
1049,
631,
300,
378,
610,
206,
588,
174,
1043,
407,
149,
912,
621,
1008,
193,
189,
1139,
673,
565,
817,
155,
114,
816,
361,
806,
458,
413,
340,
362,
700,
253,
794,
516,
685,
845,
705,
832,
1187,
900,
1078,
754,
367,
401,
976,
755,
185,
969,
1041,
180,
1105,
768,
468,
652,
944,
854,
974,
629,
144,
573,
110,
1130,
742,
1060,
775,
153,
186,
724,
301,
60,
924,
520,
505,
936,
538,
939,
579,
1037,
1058,
860,
1050,
461,
377,
807,
1064,
99,
256,
1057,
448,
115,
442,
78,
281,
917,
885,
557,
932,
400,
872,
701,
215,
747,
497,
388,
1034,
339,
342,
30,
919,
344,
670,
463,
365,
1003,
1016,
270,
780,
120,
1173,
93,
894,
187,
320,
663,
173,
261,
762,
636,
191,
140,
264,
194,
746,
260,
769,
880,
728,
235,
478,
964,
903,
74,
988,
809,
1025,
1143,
594,
1188,
1094,
736,
491,
758,
147,
879,
329,
1196,
861,
1138,
176,
205,
391,
490,
1164,
1083,
314,
222,
1132,
692,
598,
609,
938,
687,
302,
550,
237,
826,
887,
599,
150,
957,
273,
139,
190,
863,
453,
720,
1151,
509,
98,
166,
152,
410,
658,
1000,
495,
702,
529,
690,
593,
582,
802,
425,
164,
422,
240,
512,
776,
85,
823,
569,
358,
406,
613,
376,
931,
403,
716,
630,
465,
255,
661,
1163,
1030,
79,
184,
761,
332,
563,
962,
566,
102,
567,
109,
943,
732,
541,
523,
779,
37,
1123,
306,
592,
291,
713,
493,
1118,
117,
1192,
804,
933,
24,
183,
384,
675,
101,
792,
94,
896,
1070,
864,
241,
146,
892,
5,
995,
1127,
105,
544,
577,
1174,
751,
111,
1040,
385,
542,
1178,
749,
982,
707,
1011,
1069,
268,
1134,
1045,
333,
28,
133,
436,
229,
435,
868,
604,
711,
276,
548,
223,
996,
18,
909,
428,
1191,
430,
869,
719,
778,
508,
1102,
381,
937,
441,
993,
44,
820,
651,
1082,
1032,
926,
665,
618,
471,
382,
1068,
182,
580,
81,
814,
389,
740,
733,
3,
568,
130,
126,
438,
336,
195,
271,
369,
311,
600,
398,
662,
395,
359,
1116,
322,
160,
323,
501,
1066,
527,
88,
729,
1126,
985,
494,
70,
902,
127,
47,
1136,
1168,
56,
877,
1,
839,
795,
330,
484,
52,
433,
532,
1106,
632,
275,
1137,
1012,
774,
51,
409,
1091,
584,
643,
499,
7,
843,
1175,
1022,
1165,
280,
810,
245,
958,
967,
836,
908,
547,
125,
202,
226,
360,
62,
801,
831,
997,
553,
989,
258,
128,
1183,
116,
626,
1111,
54,
1001,
549,
537,
20,
980,
244,
307,
515,
1023,
1077,
984,
492,
13,
50,
590,
440,
921,
904,
918,
203,
946,
524
]
}
================================================
FILE: datainfo/swingstick_magnetic/data_split_dict_1.json
================================================
{
"test": [
2,
18,
4
],
"val": [
3,
8
],
"train": [
19,
5,
13,
9,
20,
21,
11,
12,
0,
16,
1,
17,
6,
10,
7,
14,
15
]
}
================================================
FILE: datainfo/swingstick_magnetic/data_split_dict_2.json
================================================
{
"test": [
20,
2,
1
],
"val": [
5,
11
],
"train": [
18,
21,
13,
10,
4,
17,
7,
15,
6,
19,
12,
0,
14,
3,
16,
8,
9
]
}
================================================
FILE: datainfo/swingstick_magnetic/data_split_dict_3.json
================================================
{
"test": [
17,
18,
7
],
"val": [
11,
4
],
"train": [
1,
13,
16,
5,
10,
6,
19,
12,
14,
3,
8,
20,
21,
0,
9,
2,
15
]
}
================================================
FILE: datainfo/swingstick_non_magnetic/data_split_dict_1.json
================================================
{
"test": [
48,
60,
57,
63,
15,
32,
8,
72,
17
],
"val": [
77,
0,
55,
49,
3,
62,
12,
26
],
"train": [
69,
20,
80,
10,
4,
59,
67,
81,
5,
2,
37,
78,
83,
58,
71,
19,
38,
73,
25,
79,
70,
30,
9,
36,
51,
52,
53,
16,
54,
23,
47,
21,
7,
39,
11,
6,
45,
74,
50,
18,
65,
42,
44,
22,
75,
35,
31,
28,
14,
33,
76,
46,
27,
64,
43,
24,
56,
68,
66,
41,
61,
82,
1,
40,
13,
29,
34
]
}
================================================
FILE: datainfo/swingstick_non_magnetic/data_split_dict_2.json
================================================
{
"test": [
4,
27,
32,
39,
21,
46,
10,
11,
7
],
"val": [
64,
56,
47,
65,
50,
55,
20,
74
],
"train": [
0,
31,
70,
6,
49,
41,
2,
67,
17,
13,
62,
9,
5,
57,
19,
25,
78,
18,
37,
79,
45,
51,
83,
16,
69,
71,
36,
66,
12,
61,
30,
54,
38,
40,
22,
42,
72,
26,
28,
63,
53,
43,
23,
44,
77,
8,
48,
60,
52,
1,
14,
15,
82,
35,
81,
33,
68,
76,
24,
58,
73,
59,
29,
80,
3,
75,
34
]
}
================================================
FILE: datainfo/swingstick_non_magnetic/data_split_dict_3.json
================================================
{
"test": [
8,
74,
60,
77,
47,
16,
69,
75,
30
],
"val": [
68,
73,
24,
29,
70,
33,
78,
1
],
"train": [
3,
15,
14,
22,
21,
51,
46,
52,
35,
5,
41,
56,
40,
54,
62,
57,
53,
7,
64,
26,
45,
58,
11,
18,
12,
32,
67,
9,
34,
20,
44,
43,
80,
13,
31,
39,
66,
6,
23,
82,
28,
36,
25,
27,
61,
38,
83,
17,
76,
63,
2,
37,
48,
10,
4,
49,
42,
0,
79,
81,
72,
59,
55,
65,
71,
19,
50
]
}
================================================
FILE: dataset.py
================================================
import os
import sys
import json
import glob
import torch
import itertools
import numpy as np
from PIL import Image
from scipy import misc
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
class NeuralPhysDataset(Dataset):
    """Sliding-window frame dataset: two input frames -> two target frames.

    For every video listed under ``flag`` in
    ``../datainfo/<object_name>/data_split_dict_<seed>.json`` (resolved
    relative to the current working directory), each run of four consecutive
    frames ``t .. t+3`` becomes one sample: frames ``t, t+1`` are the input and
    frames ``t+2, t+3`` are the target.

    Args:
        data_filepath: root directory; frames are read from
            ``<data_filepath>/<object_name>/<video_id>/<frame_index>.<ext>``.
        flag: key of the split JSON to use (e.g. ``'train'``, ``'val'``,
            ``'test'``).
        seed: index of the split file (``data_split_dict_<seed>.json``).
        object_name: dataset/object sub-directory name.
    """

    # Frames consumed per sample: 2 input frames + 2 target frames.
    _WINDOW = 4

    def __init__(self, data_filepath, flag, seed, object_name="double_pendulum"):
        self.seed = seed
        self.flag = flag
        self.object_name = object_name
        self.data_filepath = data_filepath
        self.all_filelist = self.get_all_filelist()

    def get_all_filelist(self):
        """Build the list of frame-path windows for every video in the split.

        Only directory entries whose name starts with ``<integer>.`` count as
        frames. Stray files such as ``.DS_Store`` therefore do not inflate the
        frame count, which would otherwise create windows pointing at
        non-existent frames, and they cannot decide the file extension. Videos
        without any frame are skipped.

        Returns:
            A list of lists, each holding ``_WINDOW`` frame paths in temporal
            order.
        """
        filelist = []
        obj_filepath = os.path.join(self.data_filepath, self.object_name)
        # get the video ids based on training or testing data
        split_filepath = os.path.join('../datainfo', self.object_name,
                                      f'data_split_dict_{self.seed}.json')
        with open(split_filepath, 'r') as file:
            seq_dict = json.load(file)
        vid_list = seq_dict[self.flag]
        # slide a window of _WINDOW consecutive frames over each selected video:
        # input (t, t+1), output (t+2, t+3)
        for vid_idx in vid_list:
            seq_filepath = os.path.join(obj_filepath, str(vid_idx))
            frame_names = [name for name in os.listdir(seq_filepath)
                           if name.split('.')[0].isdigit()]
            if not frame_names:
                continue
            num_frames = len(frame_names)
            suf = frame_names[0].split('.')[-1]
            for p_frame in range(num_frames - self._WINDOW + 1):
                filelist.append([os.path.join(seq_filepath, str(p_frame + p) + '.' + suf)
                                 for p in range(self._WINDOW)])
        return filelist

    def __len__(self):
        return len(self.all_filelist)

    def __getitem__(self, idx):
        """Return ``(data, target, filepath)`` for window ``idx``.

        ``data`` holds frames 0 and 1 of the window, and ``target`` holds
        frames 2 and 3. Each pair is concatenated along dim 2, which is the
        width axis of the ``(C, H, W)`` tensors returned by ``get_data``.
        ``filepath`` is ``<video_dir>_<first frame file name>``.
        """
        par_list = self.all_filelist[idx]
        data = torch.cat([self.get_data(par_list[i]) for i in range(2)], 2)  # 0, 1
        target = torch.cat([self.get_data(par_list[-2]),    # 2
                            self.get_data(par_list[-1])], 2)  # 3
        first_frame = par_list[0]
        # os.path keeps this portable; on POSIX it equals '/'-splitting the path
        filepath = '_'.join([os.path.basename(os.path.dirname(first_frame)),
                             os.path.basename(first_frame)])
        return data, target, filepath

    def get_data(self, filepath):
        """Load one frame, resize it to 128x128 and return a float tensor.

        Pixel values are divided by 255 and the channel axis is moved first,
        giving ``(C, 128, 128)``. This assumes an image with an explicit
        channel axis, because ``permute(2, 0, 1)`` needs three dims.
        """
        # context manager releases the underlying file handle deterministically
        with Image.open(filepath) as img:
            data = np.array(img.resize((128, 128)))
        data = torch.tensor(data / 255.0)
        data = data.permute(2, 0, 1).float()
        return data
class NeuralPhysLatentDynamicsDataset(Dataset):
    """Sliding-window frame dataset with two future targets.

    For every video listed under ``flag`` in
    ``../datainfo/<object_name>/data_split_dict_<seed>.json`` (resolved
    relative to the current working directory), each run of six consecutive
    frames ``t .. t+5`` becomes one sample:

    * input: frames ``t, t+1``
    * target: frames ``t+2, t+3``
    * target_target: frames ``t+4, t+5``

    Args:
        data_filepath: root directory; frames are read from
            ``<data_filepath>/<object_name>/<video_id>/<frame_index>.<ext>``.
        flag: key of the split JSON to use (e.g. ``'train'``, ``'val'``,
            ``'test'``).
        seed: index of the split file (``data_split_dict_<seed>.json``).
        object_name: dataset/object sub-directory name.
    """

    # Frames consumed per sample: 2 input + 2 target + 2 target_target frames.
    _WINDOW = 6

    def __init__(self, data_filepath, flag, seed, object_name="double_pendulum"):
        self.seed = seed
        self.flag = flag
        self.object_name = object_name
        self.data_filepath = data_filepath
        self.all_filelist = self.get_all_filelist()

    def get_all_filelist(self):
        """Build the list of frame-path windows for every video in the split.

        Only directory entries whose name starts with ``<integer>.`` count as
        frames. Stray files such as ``.DS_Store`` therefore do not inflate the
        frame count, which would otherwise create windows pointing at
        non-existent frames, and they cannot decide the file extension. Videos
        without any frame are skipped.

        Returns:
            A list of lists, each holding ``_WINDOW`` frame paths in temporal
            order.
        """
        filelist = []
        obj_filepath = os.path.join(self.data_filepath, self.object_name)
        # get the video ids based on training or testing data
        split_filepath = os.path.join('../datainfo', self.object_name,
                                      f'data_split_dict_{self.seed}.json')
        with open(split_filepath, 'r') as file:
            seq_dict = json.load(file)
        vid_list = seq_dict[self.flag]
        # slide a window of _WINDOW consecutive frames over each selected video:
        # input (t, t+1), target (t+2, t+3), target_target (t+4, t+5)
        for vid_idx in vid_list:
            seq_filepath = os.path.join(obj_filepath, str(vid_idx))
            frame_names = [name for name in os.listdir(seq_filepath)
                           if name.split('.')[0].isdigit()]
            if not frame_names:
                continue
            num_frames = len(frame_names)
            suf = frame_names[0].split('.')[-1]
            for p_frame in range(num_frames - self._WINDOW + 1):
                filelist.append([os.path.join(seq_filepath, str(p_frame + p) + '.' + suf)
                                 for p in range(self._WINDOW)])
        return filelist

    def __len__(self):
        return len(self.all_filelist)

    def __getitem__(self, idx):
        """Return ``(data, target, target_target, filepath)`` for window ``idx``.

        The three tensors hold frames (0, 1), (2, 3) and (4, 5) of the window.
        Each pair is concatenated along dim 2, which is the width axis of the
        ``(C, H, W)`` tensors returned by ``get_data``. ``filepath`` is
        ``<video_dir>_<first frame file name>``.
        """
        par_list = self.all_filelist[idx]
        data = torch.cat([self.get_data(par_list[i]) for i in range(2)], 2)  # 0, 1
        target = torch.cat([self.get_data(par_list[2]),    # 2
                            self.get_data(par_list[3])], 2)  # 3
        target_target = torch.cat([self.get_data(par_list[-2]),    # 4
                                   self.get_data(par_list[-1])], 2)  # 5
        first_frame = par_list[0]
        # os.path keeps this portable; on POSIX it equals '/'-splitting the path
        filepath = '_'.join([os.path.basename(os.path.dirname(first_frame)),
                             os.path.basename(first_frame)])
        return data, target, target_target, filepath

    def get_data(self, filepath):
        """Load one frame, resize it to 128x128 and return a float tensor.

        Pixel values are divided by 255 and the channel axis is moved first,
        giving ``(C, 128, 128)``. This assumes an image with an explicit
        channel axis, because ``permute(2, 0, 1)`` needs three dims.
        """
        # context manager releases the underlying file handle deterministically
        with Image.open(filepath) as img:
            data = np.array(img.resize((128, 128)))
        data = torch.tensor(data / 255.0)
        data = data.permute(2, 0, 1).float()
        return data
================================================
FILE: eval.py
================================================
import os
import sys
import glob
import yaml
import torch
import pprint
import shutil
import numpy as np
from tqdm import tqdm
from munch import munchify
from collections import OrderedDict
from models import VisDynamicsModel
from models_latentpred import VisLatentDynamicsModel
from dataset import NeuralPhysDataset
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
def mkdir(folder):
    """Make ``folder`` an empty directory.

    If something already exists at ``folder`` it is removed recursively first,
    so any previous contents are lost.
    """
    already_present = os.path.exists(folder)
    if already_present:
        shutil.rmtree(folder)  # wipe previous contents
    os.makedirs(folder)
def load_config(filepath):
    """Load a YAML config file with ``yaml.safe_load``.

    Args:
        filepath: path to the YAML file.

    Returns:
        The parsed document, as produced by ``yaml.safe_load``.

    Raises:
        yaml.YAMLError: if the file is not valid YAML. The error is printed and
            then re-raised. Previously it was swallowed and ``None`` was
            returned, which made callers fail later with an unrelated
            attribute error.
    """
    with open(filepath, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
def seed(cfg):
    """Seed torch's CPU RNG with ``cfg.seed``, plus the CUDA RNG if ``cfg.if_cuda``."""
    value = cfg.seed
    torch.manual_seed(value)
    if not cfg.if_cuda:
        return
    torch.cuda.manual_seed(value)
def _first_checkpoint(dirpath):
    """Return the first ``*.ckpt`` file found in ``dirpath``.

    This keeps the original ``glob(...)[0]`` selection, so with several
    checkpoints the choice follows ``glob`` order. When none exists it raises
    a descriptive ``FileNotFoundError`` instead of a bare ``IndexError``.
    """
    matches = glob.glob(os.path.join(dirpath, '*.ckpt'))
    if not matches:
        raise FileNotFoundError(f'no *.ckpt checkpoint found in {dirpath!r}')
    return matches[0]


def main():
    """Evaluate a trained ``VisDynamicsModel`` on the test split and save outputs.

    Positional command-line arguments (``sys.argv``):
        1. path to the YAML config;
        2. directory containing the model checkpoint (``*.ckpt``);
        3. only read when ``cfg.model_name`` contains ``'refine'``: the
           directory containing the high-dim model checkpoint.

    Raises:
        FileNotFoundError: if a checkpoint directory holds no ``*.ckpt`` file.
    """
    config_filepath = str(sys.argv[1])
    checkpoint_filepath = _first_checkpoint(str(sys.argv[2]))
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    seed(cfg)
    seed_everything(cfg.seed)
    log_dir = '_'.join([cfg.log_dir,
                        cfg.dataset,
                        cfg.model_name,
                        str(cfg.seed)])
    model = VisDynamicsModel(lr=cfg.lr,
                             seed=cfg.seed,
                             if_cuda=cfg.if_cuda,
                             if_test=True,
                             gamma=cfg.gamma,
                             log_dir=log_dir,
                             train_batch=cfg.train_batch,
                             val_batch=cfg.val_batch,
                             test_batch=cfg.test_batch,
                             num_workers=cfg.num_workers,
                             model_name=cfg.model_name,
                             data_filepath=cfg.data_filepath,
                             dataset=cfg.dataset,
                             lr_schedule=cfg.lr_schedule)
    ckpt = torch.load(checkpoint_filepath)
    if 'refine' in cfg.model_name:
        # Strip the 'model.' / 'high_dim_model.' key prefixes (see
        # rename_ckpt_for_multi_models) so the weights load directly into the
        # inner networks model.model and model.high_dim_model.
        ckpt = rename_ckpt_for_multi_models(ckpt)
        model.model.load_state_dict(ckpt)
        high_dim_checkpoint_filepath = _first_checkpoint(str(sys.argv[3]))
        ckpt = torch.load(high_dim_checkpoint_filepath)
        ckpt = rename_ckpt_for_multi_models(ckpt)
        model.high_dim_model.load_state_dict(ckpt)
    else:
        model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model.freeze()
    # NOTE(review): always requests one GPU, regardless of cfg.if_cuda.
    trainer = Trainer(gpus=1,
                      deterministic=True,
                      amp_backend='native',
                      default_root_dir=log_dir,
                      val_check_interval=1.0)
    trainer.test(model)
    model.test_save()
def main_latentpred():
    """Run the test loop for a trained ``VisLatentDynamicsModel``.

    Positional command-line arguments (used as given -- no globbing here):
        argv[1]: YAML config path.
        argv[2]: latent-prediction checkpoint, passed to ``model.load_model``.
        argv[3]: high-dim model checkpoint and
        argv[4]: refine model checkpoint, both passed to
            ``model.load_high_dim_refine_model``.
    argv[5] is the ``eval-latentpred`` mode flag checked by the ``__main__``
    dispatcher. After loading, the decoder is extracted from the refine
    model (``extract_decoder_from_refine_model``) before testing.
    """
    config_path, latentpred_ckpt, high_dim_ckpt, refine_ckpt = [
        str(sys.argv[position]) for position in (1, 2, 3, 4)
    ]

    raw_cfg = load_config(filepath=config_path)
    pprint.pprint(raw_cfg)
    cfg = munchify(raw_cfg)
    seed(cfg)
    seed_everything(cfg.seed)

    log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])
    model = VisLatentDynamicsModel(
        lr=cfg.lr,
        seed=cfg.seed,
        if_cuda=cfg.if_cuda,
        if_test=True,
        gamma=cfg.gamma,
        log_dir=log_dir,
        train_batch=cfg.train_batch,
        val_batch=cfg.val_batch,
        test_batch=cfg.test_batch,
        num_workers=cfg.num_workers,
        model_name=cfg.model_name,
        data_filepath=cfg.data_filepath,
        dataset=cfg.dataset,
        lr_schedule=cfg.lr_schedule,
    )
    model.load_model(latentpred_ckpt)
    model.load_high_dim_refine_model(high_dim_ckpt, refine_ckpt)
    model.extract_decoder_from_refine_model()
    model.eval()
    model.freeze()

    trainer = Trainer(
        gpus=1,
        deterministic=True,
        amp_backend='native',
        default_root_dir=log_dir,
        val_check_interval=1.0,
    )
    trainer.test(model)
def rename_ckpt_for_multi_models(ckpt):
    """Strip the LightningModule attribute prefix from a checkpoint's state-dict keys.

    ``ckpt['state_dict']`` holds keys such as ``model.layer1.linear.weight`` or
    ``high_dim_model.conv_stack1.0.0.weight``. The returned ``OrderedDict``
    maps the same values to the keys without the leading ``high_dim_model.``
    / ``model.`` prefix, so it can be loaded directly into the sub-module.

    Only the leading prefix is removed. The previous ``str.replace`` version
    also deleted ``model.`` wherever it appeared later in a key, for example
    ``model.submodel.w`` -> ``subw``. Keys without either prefix are kept
    unchanged, and insertion order is preserved.

    Args:
        ckpt: checkpoint dict with a ``'state_dict'`` mapping.

    Returns:
        OrderedDict of renamed key -> original value.
    """
    renamed_state_dict = OrderedDict()
    for key, value in ckpt['state_dict'].items():
        # 'high_dim_model.' is checked first because it also ends in 'model.'.
        for prefix in ('high_dim_model.', 'model.'):
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        renamed_state_dict[key] = value
    return renamed_state_dict
def _save_high_dim_latents(model, model_name, loader, out_dir):
    """Run ``loader`` through ``model.model`` and save one flattened latent per sample.

    Args:
        model: frozen ``VisDynamicsModel`` already moved to the GPU.
        model_name: ``'encoder-decoder'`` or ``'encoder-decoder-64'``; selects
            the call signature of ``model.model``.
        loader: DataLoader yielding ``(data, target, filepath)`` batches.
        out_dir: output folder, recreated from scratch via ``mkdir``.

    Writes ``ids.npy`` (the loader's ``filepath`` entries in iteration order)
    and ``latent.npy`` (flattened latents in the same order) into ``out_dir``.
    """
    all_filepaths = []
    all_latents = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data, _target, filepath in tqdm(loader):
            data = data.cuda()
            if model_name == 'encoder-decoder':
                _, latent = model.model(data)
            else:  # 'encoder-decoder-64' (validated by the caller)
                _, latent = model.model(data, data, False)
            all_filepaths.extend(filepath)
            for idx in range(data.shape[0]):
                all_latents.append(latent[idx].view(1, -1)[0].cpu().detach().numpy())
    mkdir(out_dir)
    np.save(os.path.join(out_dir, 'ids.npy'), all_filepaths)
    np.save(os.path.join(out_dir, 'latent.npy'), all_latents)


def gather_latent_from_trained_high_dim_model():
    """Gather latent variables by running the train/val data through a trained high-dim model.

    Positional command-line arguments:
        argv[1]: YAML config path.
        argv[2]: directory holding the model checkpoint (first ``*.ckpt``
            match returned by ``glob``).

    Writes ``ids.npy`` and ``latent.npy`` into ``<log_dir>/variables_train``
    and ``<log_dir>/variables_val``. These latents are used later to train
    and validate the refine network.

    Raises:
        ValueError: if ``cfg.model_name`` is not ``'encoder-decoder'`` or
            ``'encoder-decoder-64'``. Previously this surfaced as a
            ``NameError`` on ``latent`` inside the loop.
    """
    config_filepath = str(sys.argv[1])
    checkpoint_filepath = str(sys.argv[2])
    checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    if cfg.model_name not in ('encoder-decoder', 'encoder-decoder-64'):
        raise ValueError(
            "gather_latent_from_trained_high_dim_model supports model_name "
            "'encoder-decoder' or 'encoder-decoder-64', got %r" % (cfg.model_name,))
    seed(cfg)
    seed_everything(cfg.seed)
    log_dir = '_'.join([cfg.log_dir,
                        cfg.dataset,
                        cfg.model_name,
                        str(cfg.seed)])
    model = VisDynamicsModel(lr=cfg.lr,
                             seed=cfg.seed,
                             if_cuda=cfg.if_cuda,
                             if_test=True,
                             gamma=cfg.gamma,
                             log_dir=log_dir,
                             train_batch=cfg.train_batch,
                             val_batch=cfg.val_batch,
                             test_batch=cfg.test_batch,
                             num_workers=cfg.num_workers,
                             model_name=cfg.model_name,
                             data_filepath=cfg.data_filepath,
                             dataset=cfg.dataset,
                             lr_schedule=cfg.lr_schedule)
    ckpt = torch.load(checkpoint_filepath)
    model.load_state_dict(ckpt['state_dict'])
    model = model.to('cuda')
    model.eval()
    model.freeze()
    # prepare train and val dataset
    kwargs = {'num_workers': cfg.num_workers, 'pin_memory': True} if cfg.if_cuda else {}
    train_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath,
                                      flag='train',
                                      seed=cfg.seed,
                                      object_name=cfg.dataset)
    val_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath,
                                    flag='val',
                                    seed=cfg.seed,
                                    object_name=cfg.dataset)
    # prepare train and val loader; the train loader stays shuffled as before,
    # and ids.npy records the resulting order so latents stay aligned with ids
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=cfg.train_batch,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=cfg.val_batch,
                                             shuffle=False,
                                             **kwargs)
    var_log_dir = os.path.join(log_dir, 'variables')
    _save_high_dim_latents(model, cfg.model_name, train_loader, var_log_dir + '_train')
    _save_high_dim_latents(model, cfg.model_name, val_loader, var_log_dir + '_val')
def _save_refine_latents(model, loader, out_dir):
    """Run ``loader`` through the high-dim model and then the refine model, saving all latents.

    Args:
        model: frozen ``VisDynamicsModel`` (refine variant, loaded with
            ``if_test=True`` so ``high_dim_model`` exists) already on the GPU.
        loader: DataLoader yielding ``(data, target, filepath)`` batches.
        out_dir: output folder, recreated from scratch via ``mkdir``.

    Writes into ``out_dir``, one row per sample in loader order:
        ``ids.npy``                  -- the loader's ``filepath`` entries;
        ``latent.npy``               -- high-dim encoder latent (flattened);
        ``refine_latent.npy``        -- bottleneck latent of the refine net;
        ``reconstructed_latent.npy`` -- refine net's reconstruction of the
                                        high-dim latent.
    """
    all_filepaths = []
    all_latents = []
    all_refine_latents = []
    all_reconstructed_latents = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data, _target, filepath in tqdm(loader):
            data = data.cuda()
            _, latent = model.high_dim_model(data, data, False)
            # Drop trailing size-1 spatial dims so the refine MLP gets a flat
            # vector per sample (presumably (batch, 64, 1, 1) -> (batch, 64)).
            latent = latent.squeeze(-1).squeeze(-1)
            latent_reconstructed, latent_latent = model.model(latent)
            all_filepaths.extend(filepath)
            for idx in range(data.shape[0]):
                all_latents.append(latent[idx].view(1, -1)[0].cpu().detach().numpy())
                all_refine_latents.append(latent_latent[idx].view(1, -1)[0].cpu().detach().numpy())
                all_reconstructed_latents.append(
                    latent_reconstructed[idx].view(1, -1)[0].cpu().detach().numpy())
    mkdir(out_dir)
    np.save(os.path.join(out_dir, 'ids.npy'), all_filepaths)
    np.save(os.path.join(out_dir, 'latent.npy'), all_latents)
    np.save(os.path.join(out_dir, 'refine_latent.npy'), all_refine_latents)
    np.save(os.path.join(out_dir, 'reconstructed_latent.npy'), all_reconstructed_latents)


def gather_latent_from_trained_refine_model():
    """Gather latent variables by running the train/val data through a trained refine model.

    Positional command-line arguments:
        argv[1]: YAML config path.
        argv[2]: directory holding the refine model checkpoint.
        argv[3]: directory holding the high-dim model checkpoint.
    In both directories the first ``*.ckpt`` match returned by ``glob`` is
    used. Keys are renamed with ``rename_ckpt_for_multi_models`` before
    loading.

    Outputs go to ``<log_dir>/variables_train`` and ``<log_dir>/variables_val``
    (see ``_save_refine_latents`` for the files written).
    """
    config_filepath = str(sys.argv[1])
    checkpoint_filepath = str(sys.argv[2])
    checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0]
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    seed(cfg)
    seed_everything(cfg.seed)
    log_dir = '_'.join([cfg.log_dir,
                        cfg.dataset,
                        cfg.model_name,
                        str(cfg.seed)])
    model = VisDynamicsModel(lr=cfg.lr,
                             seed=cfg.seed,
                             if_cuda=cfg.if_cuda,
                             if_test=True,
                             gamma=cfg.gamma,
                             log_dir=log_dir,
                             train_batch=cfg.train_batch,
                             val_batch=cfg.val_batch,
                             test_batch=cfg.test_batch,
                             num_workers=cfg.num_workers,
                             model_name=cfg.model_name,
                             data_filepath=cfg.data_filepath,
                             dataset=cfg.dataset,
                             lr_schedule=cfg.lr_schedule)
    ckpt = torch.load(checkpoint_filepath)
    ckpt = rename_ckpt_for_multi_models(ckpt)
    model.model.load_state_dict(ckpt)
    high_dim_checkpoint_filepath = str(sys.argv[3])
    high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0]
    ckpt = torch.load(high_dim_checkpoint_filepath)
    ckpt = rename_ckpt_for_multi_models(ckpt)
    model.high_dim_model.load_state_dict(ckpt)
    model = model.to('cuda')
    model.eval()
    model.freeze()
    # prepare train and val dataset
    kwargs = {'num_workers': cfg.num_workers, 'pin_memory': True} if cfg.if_cuda else {}
    train_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath,
                                      flag='train',
                                      seed=cfg.seed,
                                      object_name=cfg.dataset)
    val_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath,
                                    flag='val',
                                    seed=cfg.seed,
                                    object_name=cfg.dataset)
    # prepare train and val loader; ids.npy records the (shuffled) train order
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=cfg.train_batch,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=cfg.val_batch,
                                             shuffle=False,
                                             **kwargs)
    var_log_dir = os.path.join(log_dir, 'variables')
    _save_refine_latents(model, train_loader, var_log_dir + '_train')
    _save_refine_latents(model, val_loader, var_log_dir + '_val')
if __name__ == '__main__':
    # Mode flags:
    #   argv[4] == 'eval-train'        -> gather latents from the high-dim model
    #   argv[4] == 'eval-refine-train' -> gather latents from the refine model
    #   argv[5] == 'eval-latentpred'   -> evaluate the latent-prediction model
    #   anything else                  -> main()
    # Missing arguments count as "no flag", so short command lines reach
    # main() instead of raising IndexError on sys.argv[4] / sys.argv[5].
    mode_arg4 = str(sys.argv[4]) if len(sys.argv) > 4 else None
    mode_arg5 = str(sys.argv[5]) if len(sys.argv) > 5 else None
    if mode_arg4 == 'eval-train':
        gather_latent_from_trained_high_dim_model()
    elif mode_arg4 == 'eval-refine-train':
        gather_latent_from_trained_refine_model()
    elif mode_arg5 == 'eval-latentpred':
        main_latentpred()
    else:
        main()
================================================
FILE: main.py
================================================
import os
import sys
import yaml
import torch
import pprint
from munch import munchify
from models import VisDynamicsModel
from models_latentpred import VisLatentDynamicsModel
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
def load_config(filepath):
    """Parse the YAML file at ``filepath`` with ``yaml.safe_load``.

    Callers pass the result to ``munchify`` and then read attributes such as
    ``cfg.seed``, so a mapping is expected.

    Raises:
        yaml.YAMLError: if the file is not valid YAML. The error is printed
            (as before) and then re-raised, instead of returning ``None`` and
            failing later with an unrelated ``AttributeError``.
    """
    with open(filepath, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
def seed(cfg):
    """Seed torch from ``cfg.seed``; the CUDA generator is seeded too when ``cfg.if_cuda``."""
    chosen_seed = cfg.seed
    torch.manual_seed(chosen_seed)
    if cfg.if_cuda:
        torch.cuda.manual_seed(chosen_seed)
def main():
    """Train a ``VisDynamicsModel`` described by the YAML config at ``argv[1]``.

    Runs are logged under ``<log_dir>_<dataset>_<model_name>_<seed>``.
    Checkpoints are written to ``<that dir>/lightning_logs/checkpoints`` by a
    ``ModelCheckpoint`` that monitors ``val_loss`` with ``mode='min'``.
    Training uses ``cfg.num_gpus`` GPUs with the ``'ddp'`` accelerator for
    ``cfg.epochs`` epochs.

    NOTE(review): ``filepath=``/``prefix=`` on ``ModelCheckpoint`` and passing
    the callback via ``checkpoint_callback=`` are older PyTorch Lightning
    APIs; confirm against the pinned version before upgrading.
    """
    raw_cfg = load_config(filepath=str(sys.argv[1]))
    pprint.pprint(raw_cfg)
    cfg = munchify(raw_cfg)
    seed(cfg)
    seed_everything(cfg.seed)

    log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])
    model = VisDynamicsModel(
        lr=cfg.lr,
        seed=cfg.seed,
        if_cuda=cfg.if_cuda,
        if_test=False,
        gamma=cfg.gamma,
        log_dir=log_dir,
        train_batch=cfg.train_batch,
        val_batch=cfg.val_batch,
        test_batch=cfg.test_batch,
        num_workers=cfg.num_workers,
        model_name=cfg.model_name,
        data_filepath=cfg.data_filepath,
        dataset=cfg.dataset,
        lr_schedule=cfg.lr_schedule,
    )

    checkpoint_callback = ModelCheckpoint(
        filepath=log_dir + "/lightning_logs/checkpoints/{epoch}_{val_loss}",
        verbose=True,
        monitor='val_loss',
        mode='min',
        prefix='',
    )
    trainer = Trainer(
        gpus=cfg.num_gpus,
        max_epochs=cfg.epochs,
        deterministic=True,
        accelerator='ddp',
        amp_backend='native',
        default_root_dir=log_dir,
        val_check_interval=1.0,
        checkpoint_callback=checkpoint_callback,
    )
    trainer.fit(model)
def main_latentpred():
    """Train a ``VisLatentDynamicsModel`` on top of pretrained high-dim and refine models.

    Positional command-line arguments (argv[1] is the ``'latentpred'`` mode flag):
        argv[2]: YAML config path.
        argv[3]: high-dim model checkpoint.
        argv[4]: refine model checkpoint.
    Both checkpoints are handed to ``model.load_high_dim_refine_model``
    before training. Checkpointing and trainer settings match ``main``.
    """
    config_path = str(sys.argv[2])
    high_dim_ckpt = str(sys.argv[3])
    refine_ckpt = str(sys.argv[4])

    raw_cfg = load_config(filepath=config_path)
    pprint.pprint(raw_cfg)
    cfg = munchify(raw_cfg)
    seed(cfg)
    seed_everything(cfg.seed)

    log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])
    model = VisLatentDynamicsModel(
        lr=cfg.lr,
        seed=cfg.seed,
        if_cuda=cfg.if_cuda,
        if_test=False,
        gamma=cfg.gamma,
        log_dir=log_dir,
        train_batch=cfg.train_batch,
        val_batch=cfg.val_batch,
        test_batch=cfg.test_batch,
        num_workers=cfg.num_workers,
        model_name=cfg.model_name,
        data_filepath=cfg.data_filepath,
        dataset=cfg.dataset,
        lr_schedule=cfg.lr_schedule,
    )
    model.load_high_dim_refine_model(high_dim_ckpt, refine_ckpt)

    checkpoint_callback = ModelCheckpoint(
        filepath=log_dir + "/lightning_logs/checkpoints/{epoch}_{val_loss}",
        verbose=True,
        monitor='val_loss',
        mode='min',
        prefix='',
    )
    trainer = Trainer(
        gpus=cfg.num_gpus,
        max_epochs=cfg.epochs,
        deterministic=True,
        accelerator='ddp',
        amp_backend='native',
        default_root_dir=log_dir,
        val_check_interval=1.0,
        checkpoint_callback=checkpoint_callback,
    )
    trainer.fit(model)
if __name__ == '__main__':
    # `python main.py latentpred <config> <high_dim_ckpt> <refine_ckpt>` trains
    # the latent-prediction model; otherwise argv[1] is the config path for
    # the plain visual-dynamics model.
    entry_point = main_latentpred if sys.argv[1] == 'latentpred' else main
    entry_point()
================================================
FILE: model_utils.py
================================================
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
def conv2d_bn_relu(inch, outch, kernel_size, stride=1, padding=1):
    """Return ``Conv2d -> BatchNorm2d -> ReLU`` wrapped in ``torch.nn.Sequential``.

    The three modules are registered positionally (``0``, ``1``, ``2``), so the
    state-dict keys are ``0.weight``, ``1.running_mean`` and so on.
    """
    layers = [
        torch.nn.Conv2d(inch, outch, kernel_size=kernel_size, stride=stride, padding=padding),
        torch.nn.BatchNorm2d(outch),
        torch.nn.ReLU(),
    ]
    return torch.nn.Sequential(*layers)
def conv2d_bn_sigmoid(inch, outch, kernel_size, stride=1, padding=1):
    """Return ``Conv2d -> BatchNorm2d -> Sigmoid`` as a positional ``torch.nn.Sequential``."""
    conv = torch.nn.Conv2d(inch, outch, kernel_size=kernel_size, stride=stride, padding=padding)
    norm = torch.nn.BatchNorm2d(outch)
    squash = torch.nn.Sigmoid()
    return torch.nn.Sequential(conv, norm, squash)
def deconv_sigmoid(inch, outch, kernel_size, stride=1, padding=1):
    """Return ``ConvTranspose2d -> Sigmoid`` (no normalisation) as a positional ``Sequential``.

    With ``kernel_size=4, stride=2, padding=1`` the spatial size doubles.
    """
    stages = (
        torch.nn.ConvTranspose2d(inch, outch, kernel_size=kernel_size, stride=stride, padding=padding),
        torch.nn.Sigmoid(),
    )
    return torch.nn.Sequential(*stages)
def deconv_relu(inch, outch, kernel_size, stride=1, padding=1):
    """Return ``ConvTranspose2d -> BatchNorm2d -> ReLU`` as a positional ``Sequential``."""
    upconv = torch.nn.ConvTranspose2d(inch, outch, kernel_size=kernel_size, stride=stride, padding=padding)
    norm = torch.nn.BatchNorm2d(outch)
    activation = torch.nn.ReLU()
    return torch.nn.Sequential(upconv, norm, activation)
class EncoderDecoder(torch.nn.Module):
    """Convolutional encoder-decoder with a 3-channel side prediction at each decoder scale.

    The encoder (five conv stacks) maps the input to a 128-channel feature map.
    The decoder upsamples it back to an image. At every intermediate scale, a
    3-channel prediction (``predict_k`` followed by ``up_sample_k``) is
    concatenated onto the deconvolved features. This gives the 64 + 3 = 67,
    32 + 3 = 35 and 16 + 3 = 19 input widths of the later layers.

    ``forward`` returns ``(output, latent)``, where ``latent`` is the encoder
    output and ``output`` comes from a final sigmoid with 3 channels.
    Example (arithmetic from the layer settings): a 128x256 input yields a
    128x8x8 latent.

    Attribute names and construction order determine the state-dict keys and
    the seeded weight initialisation, so keep them stable.

    Args:
        in_channels: number of input channels.
    """
    def __init__(self, in_channels):
        super(EncoderDecoder,self).__init__()
        # Stacks 1-4: a 4x4 stride-2 conv (padding 1) halves H and W (rounding
        # down), then a 3x3 conv (padding 1) keeps the size.
        self.conv_stack1 = torch.nn.Sequential(
            conv2d_bn_relu(in_channels,32,4,stride=2),
            conv2d_bn_relu(32,32,3)
        )
        self.conv_stack2 = torch.nn.Sequential(
            conv2d_bn_relu(32,32,4,stride=2),
            conv2d_bn_relu(32,32,3)
        )
        self.conv_stack3 = torch.nn.Sequential(
            conv2d_bn_relu(32,64,4,stride=2),
            conv2d_bn_relu(64,64,3)
        )
        self.conv_stack4 = torch.nn.Sequential(
            conv2d_bn_relu(64,128,4,stride=2),
            conv2d_bn_relu(128,128,3),
        )
        # Stack 5: kernel (3,4), stride (1,2) keeps H and halves only W.
        self.conv_stack5 = torch.nn.Sequential(
            conv2d_bn_relu(128,128,(3,4),stride=(1,2)),
            conv2d_bn_relu(128,128,3),
        )
        # Decoder: deconv_5 doubles only W (mirrors conv_stack5); the others
        # (4x4, stride 2, padding 1) double H and W. Input widths include the
        # 3 extra channels of the concatenated side prediction.
        self.deconv_5 = deconv_relu(128,64,(3,4),stride=(1,2))
        self.deconv_4 = deconv_relu(67,64,4,stride=2)
        self.deconv_3 = deconv_relu(67,32,4,stride=2)
        self.deconv_2 = deconv_relu(35,16,4,stride=2)
        self.deconv_1 = deconv_sigmoid(19,3,4,stride=2)
        # Side heads: 3x3 conv down to a 3-channel prediction at each scale.
        self.predict_5 = torch.nn.Conv2d(128,3,3,stride=1,padding=1)
        self.predict_4 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)
        self.predict_3 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)
        self.predict_2 = torch.nn.Conv2d(35,3,3,stride=1,padding=1)
        # Upsample each side prediction to the next scale (same geometry as
        # the matching deconv) and squash it with a sigmoid.
        self.up_sample_5 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,(3,4),stride=(1,2),padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_4 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_3 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_2 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
    def encoder(self, x):
        """Apply conv stacks 1-5 and return the 128-channel feature map."""
        conv1_out = self.conv_stack1(x)
        conv2_out = self.conv_stack2(conv1_out)
        conv3_out = self.conv_stack3(conv2_out)
        conv4_out = self.conv_stack4(conv3_out)
        conv5_out = self.conv_stack5(conv4_out)
        return conv5_out
    def decoder(self, x):
        """Upsample an encoder latent back to a 3-channel image in [0, 1]."""
        # At each scale: deconvolve the features, predict + upsample a
        # 3-channel side output, and concatenate both along channels.
        deconv5_out = self.deconv_5(x)
        predict_5_out = self.up_sample_5(self.predict_5(x))
        concat_5 = torch.cat([deconv5_out, predict_5_out],dim=1)
        deconv4_out = self.deconv_4(concat_5)
        predict_4_out = self.up_sample_4(self.predict_4(concat_5))
        concat_4 = torch.cat([deconv4_out,predict_4_out],dim=1)
        deconv3_out = self.deconv_3(concat_4)
        predict_3_out = self.up_sample_3(self.predict_3(concat_4))
        concat2 = torch.cat([deconv3_out,predict_3_out],dim=1)
        deconv2_out = self.deconv_2(concat2)
        predict_2_out = self.up_sample_2(self.predict_2(concat2))
        concat1 = torch.cat([deconv2_out,predict_2_out],dim=1)
        predict_out = self.deconv_1(concat1)
        return predict_out
    def forward(self,x):
        """Encode ``x`` and decode it; returns ``(output, latent)``."""
        latent = self.encoder(x)
        out = self.decoder(latent)
        return out, latent
class EncoderDecoder64x1x1(torch.nn.Module):
    """Deeper encoder-decoder whose bottleneck has 64 channels.

    Stacks 1-7 each halve H and W (rounding down). Stack 8 halves only W. So
    a 128x256 input gives a 64x1x1 latent. That input size is an assumption;
    VisDynamicsModel.test_step slices frames at column 128 (TODO: confirm
    against the dataset).

    The decoder mirrors the encoder and adds a 3-channel side prediction at
    every scale, as in ``EncoderDecoder``; this is why most decoder inputs
    have 64 + 3 = 67 channels.

    ``forward(x, reconstructed_latent, refine_latent)`` always encodes ``x``.
    It decodes either that latent or, when ``refine_latent`` is True, the
    supplied ``reconstructed_latent``. It returns ``(output, encoder_latent)``.

    Attribute names and construction order determine the state-dict keys and
    the seeded weight initialisation, so keep them stable.

    Args:
        in_channels: number of input channels.
    """
    def __init__(self, in_channels):
        super(EncoderDecoder64x1x1,self).__init__()
        # Stacks 1-7: 4x4 stride-2 conv (padding 1) halves H and W, then a
        # 3x3 conv (padding 1) keeps the size.
        self.conv_stack1 = torch.nn.Sequential(
            conv2d_bn_relu(in_channels,32,4,stride=2),
            conv2d_bn_relu(32,32,3)
        )
        self.conv_stack2 = torch.nn.Sequential(
            conv2d_bn_relu(32,32,4,stride=2),
            conv2d_bn_relu(32,32,3)
        )
        self.conv_stack3 = torch.nn.Sequential(
            conv2d_bn_relu(32,64,4,stride=2),
            conv2d_bn_relu(64,64,3)
        )
        self.conv_stack4 = torch.nn.Sequential(
            conv2d_bn_relu(64,64,4,stride=2),
            conv2d_bn_relu(64,64,3),
        )
        self.conv_stack5 = torch.nn.Sequential(
            conv2d_bn_relu(64,64,4,stride=2),
            conv2d_bn_relu(64,64,3),
        )
        self.conv_stack6 = torch.nn.Sequential(
            conv2d_bn_relu(64,64,4,stride=2),
            conv2d_bn_relu(64,64,3),
        )
        self.conv_stack7 = torch.nn.Sequential(
            conv2d_bn_relu(64,64,4,stride=2),
            conv2d_bn_relu(64,64,3),
        )
        # Stack 8: kernel (3,4), stride (1,2) keeps H and halves only W.
        self.conv_stack8 = torch.nn.Sequential(
            conv2d_bn_relu(64, 64, (3,4), stride=(1,2)),
            conv2d_bn_relu(64, 64, 3),
        )
        # Decoder: deconv_8 doubles only W (mirrors conv_stack8); the rest
        # double H and W. Inputs include the 3 side-prediction channels.
        self.deconv_8 = deconv_relu(64,64,(3,4),stride=(1,2))
        self.deconv_7 = deconv_relu(67,64,4,stride=2)
        self.deconv_6 = deconv_relu(67,64,4,stride=2)
        self.deconv_5 = deconv_relu(67,64,4,stride=2)
        self.deconv_4 = deconv_relu(67,64,4,stride=2)
        self.deconv_3 = deconv_relu(67,32,4,stride=2)
        self.deconv_2 = deconv_relu(35,16,4,stride=2)
        self.deconv_1 = deconv_sigmoid(19,3,4,stride=2)
        # Side heads: 3x3 conv down to a 3-channel prediction at each scale.
        self.predict_8 = torch.nn.Conv2d(64,3,3,stride=1,padding=1)
        self.predict_7 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)
        self.predict_6 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)
        self.predict_5 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)
        self.predict_4 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)
        self.predict_3 = torch.nn.Conv2d(67,3,3,stride=1,padding=1)
        self.predict_2 = torch.nn.Conv2d(35,3,3,stride=1,padding=1)
        # Upsample each side prediction to the next scale and apply a sigmoid.
        self.up_sample_8 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,(3,4),stride=(1,2),padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_7 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_6 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_5 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_4 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_3 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
        self.up_sample_2 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False),
            torch.nn.Sigmoid()
        )
    def encoder(self, x):
        """Apply conv stacks 1-8 and return the 64-channel latent map."""
        conv1_out = self.conv_stack1(x)
        conv2_out = self.conv_stack2(conv1_out)
        conv3_out = self.conv_stack3(conv2_out)
        conv4_out = self.conv_stack4(conv3_out)
        conv5_out = self.conv_stack5(conv4_out)
        conv6_out = self.conv_stack6(conv5_out)
        conv7_out = self.conv_stack7(conv6_out)
        conv8_out = self.conv_stack8(conv7_out)
        return conv8_out
    def decoder(self, x):
        """Upsample a 64-channel latent back to a 3-channel image in [0, 1]."""
        # At each scale: deconvolve, predict + upsample a 3-channel side
        # output, and concatenate both along channels.
        deconv8_out = self.deconv_8(x)
        predict_8_out = self.up_sample_8(self.predict_8(x))
        concat_7 = torch.cat([deconv8_out, predict_8_out], dim=1)
        deconv7_out = self.deconv_7(concat_7)
        predict_7_out = self.up_sample_7(self.predict_7(concat_7))
        concat_6 = torch.cat([deconv7_out,predict_7_out],dim=1)
        deconv6_out = self.deconv_6(concat_6)
        predict_6_out = self.up_sample_6(self.predict_6(concat_6))
        concat_5 = torch.cat([deconv6_out,predict_6_out],dim=1)
        deconv5_out = self.deconv_5(concat_5)
        predict_5_out = self.up_sample_5(self.predict_5(concat_5))
        concat_4 = torch.cat([deconv5_out,predict_5_out],dim=1)
        deconv4_out = self.deconv_4(concat_4)
        predict_4_out = self.up_sample_4(self.predict_4(concat_4))
        concat_3 = torch.cat([deconv4_out,predict_4_out],dim=1)
        deconv3_out = self.deconv_3(concat_3)
        predict_3_out = self.up_sample_3(self.predict_3(concat_3))
        concat2 = torch.cat([deconv3_out,predict_3_out],dim=1)
        deconv2_out = self.deconv_2(concat2)
        predict_2_out = self.up_sample_2(self.predict_2(concat2))
        concat1 = torch.cat([deconv2_out,predict_2_out],dim=1)
        predict_out = self.deconv_1(concat1)
        return predict_out
    def forward(self,x, reconstructed_latent, refine_latent):
        """Encode ``x``; decode its latent, or ``reconstructed_latent`` if ``refine_latent`` is True.

        Returns ``(output, latent)``, where ``latent`` is always the encoding
        of ``x``. ``reconstructed_latent`` is ignored unless
        ``refine_latent`` is True; callers in this repo pass ``data`` with
        ``False``.
        """
        # Note: compared with `!= True`, so only True (or 1) selects the
        # refine branch; any other value, truthy or not, takes this branch.
        if refine_latent != True:
            latent = self.encoder(x)
            out = self.decoder(latent)
            return out, latent
        else:
            latent = self.encoder(x)
            out = self.decoder(reconstructed_latent)
            return out, latent
class SirenLayer(nn.Module):
    """Linear layer followed by ``sin(w0 * x)``, or a plain linear layer when ``is_last``.

    Weights use the SIREN initialisation (Sitzmann et al., 2020). The first
    layer draws from ``U(-1/in_f, 1/in_f)``; later layers draw from
    ``U(-sqrt(6/in_f)/w0, +sqrt(6/in_f)/w0)``. Biases keep ``nn.Linear``'s
    default init.

    Args:
        in_f: number of input features.
        out_f: number of output features.
        w0: frequency factor applied before the sine.
        is_first: use the first-layer weight bound.
        is_last: return the raw linear output (no sine).
    """

    def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False):
        super().__init__()
        self.in_f = in_f
        self.w0 = w0
        self.linear = nn.Linear(in_f, out_f)
        self.is_first = is_first
        self.is_last = is_last
        self.init_weights()

    def init_weights(self):
        """Re-draw ``self.linear.weight`` uniformly from the SIREN bound."""
        b = 1 / self.in_f if self.is_first else np.sqrt(6 / self.in_f) / self.w0
        with torch.no_grad():
            self.linear.weight.uniform_(-b, b)

    def forward(self, x):
        x = self.linear(x)
        return x if self.is_last else torch.sin(self.w0 * x)


class _SirenBottleneckAutoencoder(torch.nn.Module):
    """Shared body of the ``Refine*Model`` classes: an 8-layer SIREN MLP autoencoder.

    Widths: ``in -> 128 -> 64 -> 32 -> latent_dim -> 32 -> 64 -> 128 -> in``.
    The last layer is linear (no sine). ``forward`` returns
    ``(reconstruction, latent)``, where ``latent`` is the ``layer4`` output
    (sine-activated, so in [-1, 1]).

    The attributes ``layer1`` ... ``layer8`` and their creation order are
    unchanged from the original per-dataset classes. This keeps the
    checkpoint state-dict keys (``layerN.linear.weight/bias``) and the seeded
    initialisation identical.

    Args:
        in_channels: size of the input and reconstructed vectors.
        latent_dim: size of the bottleneck.
    """

    def __init__(self, in_channels, latent_dim):
        super().__init__()
        self.layer1 = SirenLayer(in_channels, 128, is_first=True)
        self.layer2 = SirenLayer(128, 64)
        self.layer3 = SirenLayer(64, 32)
        self.layer4 = SirenLayer(32, latent_dim)
        self.layer5 = SirenLayer(latent_dim, 32)
        self.layer6 = SirenLayer(32, 64)
        self.layer7 = SirenLayer(64, 128)
        self.layer8 = SirenLayer(128, in_channels, is_last=True)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        latent = self.layer4(x)
        x = self.layer5(latent)
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)
        return x, latent


class RefineDoublePendulumModel(_SirenBottleneckAutoencoder):
    """Refine network for ``double_pendulum``: 4-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=4)


class RefineElasticPendulumModel(_SirenBottleneckAutoencoder):
    """Refine network for ``elastic_pendulum``: 6-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=6)


class RefineReactionDiffusionModel(_SirenBottleneckAutoencoder):
    """Refine network for ``reaction_diffusion``: 2-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=2)


class RefineSwingStickNonMagneticModel(_SirenBottleneckAutoencoder):
    """Refine network for ``swingstick_non_magnetic``: 4-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=4)


class RefineSinglePendulumModel(_SirenBottleneckAutoencoder):
    """Refine network for ``single_pendulum``: 2-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=2)


class RefineCircularMotionModel(_SirenBottleneckAutoencoder):
    """Refine network for ``circular_motion``: 2-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=2)


class RefineAirDancerModel(_SirenBottleneckAutoencoder):
    """Refine network for ``air_dancer``: 8-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=8)


class RefineLavaLampModel(_SirenBottleneckAutoencoder):
    """Refine network for ``lava_lamp``: 8-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=8)


class RefineFireModel(_SirenBottleneckAutoencoder):
    """Refine network for ``fire``: 24-dimensional bottleneck."""

    def __init__(self, in_channels):
        super().__init__(in_channels, latent_dim=24)
class RefineModelReLU(torch.nn.Module):
    """ReLU MLP autoencoder: ``in -> 128 -> 64 -> 4 -> 64 -> 128 -> in``.

    ``forward`` returns ``(reconstruction, latent)``. ``latent`` is the
    ReLU-activated 4-d bottleneck, so it is non-negative. The last layer is
    linear.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Registration order (layerN then reluN) is kept as-is: it fixes the
        # state-dict keys and the seeded initialisation order.
        # -- encoder
        self.layer1 = nn.Linear(in_channels, 128)
        self.relu1 = nn.ReLU()
        self.layer2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.layer3 = nn.Linear(64, 4)
        self.relu3 = nn.ReLU()
        # -- decoder
        self.layer4 = nn.Linear(4, 64)
        self.relu4 = nn.ReLU()
        self.layer5 = nn.Linear(64, 128)
        self.relu5 = nn.ReLU()
        self.layer6 = nn.Linear(128, in_channels)

    def forward(self, x):
        hidden = self.relu2(self.layer2(self.relu1(self.layer1(x))))
        latent = self.relu3(self.layer3(hidden))
        hidden = self.relu5(self.layer5(self.relu4(self.layer4(latent))))
        return self.layer6(hidden), latent
class LatentPredModel(torch.nn.Module):
    """ReLU MLP mapping an ``in_channels`` vector to another of the same size.

    Widths: ``in -> 32 -> 64 -> 64 -> 64 -> 32 -> in``. There is a ReLU after
    each hidden layer and none after the output layer.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Registration order (layerN then reluN) fixes the state-dict keys and
        # the seeded initialisation order; keep it.
        self.layer1 = nn.Linear(in_channels, 32)
        self.relu1 = nn.ReLU()
        self.layer2 = nn.Linear(32, 64)
        self.relu2 = nn.ReLU()
        self.layer3 = nn.Linear(64, 64)
        self.relu3 = nn.ReLU()
        self.layer4 = nn.Linear(64, 64)
        self.relu4 = nn.ReLU()
        self.layer5 = nn.Linear(64, 32)
        self.relu5 = nn.ReLU()
        self.layer6 = nn.Linear(32, in_channels)

    def forward(self, x):
        hidden_stages = (
            (self.layer1, self.relu1),
            (self.layer2, self.relu2),
            (self.layer3, self.relu3),
            (self.layer4, self.relu4),
            (self.layer5, self.relu5),
        )
        for linear, activation in hidden_stages:
            x = activation(linear(x))
        return self.layer6(x)
================================================
FILE: models.py
================================================
import os
import torch
import shutil
import numpy as np
from torch import nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
from collections import OrderedDict
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from dataset import NeuralPhysDataset
from model_utils import (EncoderDecoder,
EncoderDecoder64x1x1,
RefineDoublePendulumModel,
RefineSinglePendulumModel,
RefineCircularMotionModel,
RefineModelReLU,
RefineSwingStickNonMagneticModel,
RefineAirDancerModel,
RefineLavaLampModel,
RefineFireModel,
RefineElasticPendulumModel,
RefineReactionDiffusionModel)
def mkdir(folder):
    """(Re)create ``folder`` as an empty directory.

    An existing ``folder`` is deleted with ``shutil.rmtree`` first, so all of
    its previous contents are lost. Missing parent directories are created.
    """
    if not os.path.exists(folder):
        os.makedirs(folder)
        return
    shutil.rmtree(folder)
    os.makedirs(folder)
class VisDynamicsModel(pl.LightningModule):
def __init__(self,
lr: float=1e-4,
seed: int=1,
if_cuda: bool=True,
if_test: bool=False,
gamma: float=0.5,
log_dir: str='logs',
train_batch: int=512,
val_batch: int=256,
test_batch: int=256,
num_workers: int=8,
model_name: str='encoder-decoder-64',
data_filepath: str='data',
dataset: str='single_pendulum',
lr_schedule: list=[20, 50, 100]) -> None:
super().__init__()
self.save_hyperparameters()
self.kwargs = {'num_workers': self.hparams.num_workers, 'pin_memory': True} if self.hparams.if_cuda else {}
# create visualization saving folder if testing
self.pred_log_dir = os.path.join(self.hparams.log_dir, 'predictions')
self.var_log_dir = os.path.join(self.hparams.log_dir, 'variables')
if not self.hparams.if_test:
mkdir(self.pred_log_dir)
mkdir(self.var_log_dir)
self.__build_model()
def __build_model(self):
# model
if self.hparams.model_name == 'encoder-decoder':
self.model = EncoderDecoder(in_channels=3)
if self.hparams.model_name == 'encoder-decoder-64':
self.model = EncoderDecoder64x1x1(in_channels=3)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'single_pendulum':
self.model = RefineSinglePendulumModel(in_channels=64)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'double_pendulum':
self.model = RefineDoublePendulumModel(in_channels=64)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'circular_motion':
self.model = RefineCircularMotionModel(in_channels=64)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'swingstick_non_magnetic':
self.model = RefineSwingStickNonMagneticModel(in_channels=64)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'air_dancer':
self.model = RefineAirDancerModel(in_channels=64)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'lava_lamp':
self.model = RefineLavaLampModel(in_channels=64)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'fire':
self.model = RefineFireModel(in_channels=64)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'elastic_pendulum':
self.model = RefineElasticPendulumModel(in_channels=64)
if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'reaction_diffusion':
self.model = RefineReactionDiffusionModel(in_channels=64)
if 'refine' in self.hparams.model_name and self.hparams.if_test:
self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)
# loss
self.loss_func = nn.MSELoss()
def train_forward(self, x):
    """Forward pass used during training and validation.

    Args:
        x: input batch passed to ``self.model``.

    Returns:
        ``(output, latent)`` as returned by ``self.model``.

    Raises:
        ValueError: if ``hparams.model_name`` has no training forward pass.
            Without this check, ``output`` would be unbound and the method
            would fail with an UnboundLocalError.
    """
    model_name = self.hparams.model_name
    if model_name == 'encoder-decoder' or 'refine' in model_name:
        output, latent = self.model(x)
    elif model_name == 'encoder-decoder-64':
        # the 64-dim model takes (input, latent-or-input, decode_from_latent)
        output, latent = self.model(x, x, False)
    else:
        raise ValueError(f'unsupported model_name: {model_name!r}')
    return output, latent
def training_step(self, batch, batch_idx):
    """Compute, log (as ``train_loss``) and return the loss of one training batch.

    ``batch`` is an ``(inputs, targets, filepaths)`` triple; the file paths are unused.
    """
    inputs, targets, _paths = batch
    predictions, _latent = self.train_forward(inputs)
    loss = self.loss_func(predictions, targets)
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
def validation_step(self, batch, batch_idx):
    """Compute, log (as ``val_loss``) and return the loss of one validation batch.

    ``batch`` is an ``(inputs, targets, filepaths)`` triple; the file paths are unused.
    """
    inputs, targets, _paths = batch
    predictions, _latent = self.train_forward(inputs)
    loss = self.loss_func(predictions, targets)
    self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
def test_step(self, batch, batch_idx):
    """Evaluate one test batch, log its losses and store images and latents.

    ``encoder-decoder`` / ``encoder-decoder-64`` reconstruct the batch directly
    and log ``test_loss``.

    Refine models work in three stages:
    1. ``self.high_dim_model`` encodes the frames to a latent.
    2. The refine network reconstructs that latent.
    3. ``high_dim_model`` decodes the reconstruction.
    Refine models then log the latent MSE as ``test_loss`` and the image MSE as
    ``pixel_reconstruction_loss``.

    For every sample a comparison image (input, target and predicted frames)
    is written to ``self.pred_log_dir``. The flattened latents are appended to
    the ``all_*`` lists that ``setup('test')`` initialises. Other model names
    do nothing.
    """
    name = self.hparams.model_name
    is_autoencoder = name in ('encoder-decoder', 'encoder-decoder-64')
    is_refine = 'refine' in name
    if not (is_autoencoder or is_refine):
        return

    frames, truth, names = batch
    if is_autoencoder:
        if name == 'encoder-decoder':
            pred, code = self.model(frames)
        else:
            pred, code = self.model(frames, frames, False)
        loss = self.loss_func(pred, truth)
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    else:
        _, code = self.high_dim_model(frames, frames, False)
        code = code.squeeze(-1).squeeze(-1)
        code_rec, refine_code = self.model(code)
        pred, _ = self.high_dim_model(frames, code_rec.unsqueeze(2).unsqueeze(3), True)
        pixel_loss = self.loss_func(pred, truth)
        latent_loss = self.loss_func(code_rec, code)
        self.log('test_loss', latent_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('pixel_reconstruction_loss', pixel_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def flat_numpy(tensor):
        # flatten one sample's tensor into a 1-D numpy array
        return tensor.view(1, -1)[0].cpu().detach().numpy()

    # save the output images and latent vectors
    self.all_filepaths.extend(names)
    halves = (slice(None, 128), slice(128, None))
    for i in range(frames.shape[0]):
        # the last axis holds two frames side by side, split at column 128
        panels = [t[i, :, :, cols].unsqueeze(0) for t in (frames, truth, pred) for cols in halves]
        save_image(torch.cat(panels).cpu(), os.path.join(self.pred_log_dir, names[i]), nrow=1)
        self.all_latents.append(flat_numpy(code[i]))
        if is_refine:
            # latent inside the refine network, then the refine network's reconstruction
            self.all_refine_latents.append(flat_numpy(refine_code[i]))
            self.all_reconstructed_latents.append(flat_numpy(code_rec[i]))
def test_save(self):
    """Write the ids and latents collected by ``test_step`` as ``.npy`` files.

    Files go to ``self.var_log_dir``. Refine models additionally store the
    refine-network latents and the reconstructed 64-dim latents.
    """
    name = self.hparams.model_name
    outputs = {}
    if name in ('encoder-decoder', 'encoder-decoder-64') or 'refine' in name:
        outputs['ids.npy'] = self.all_filepaths
        outputs['latent.npy'] = self.all_latents
    if 'refine' in name:
        outputs['refine_latent.npy'] = self.all_refine_latents
        outputs['reconstructed_latent.npy'] = self.all_reconstructed_latents
    for file_name, values in outputs.items():
        np.save(os.path.join(self.var_log_dir, file_name), values)
def configure_optimizers(self):
    """Return Adam over all parameters plus a MultiStepLR decay schedule."""
    hp = self.hparams
    adam = torch.optim.Adam(self.parameters(), lr=hp.lr)
    decay = torch.optim.lr_scheduler.MultiStepLR(adam, milestones=hp.lr_schedule, gamma=hp.gamma)
    return [adam], [decay]
def paths_to_tuple(self, paths):
    """Convert file names such as ``'12_34.png'`` into ``(12, 34)`` integer tuples.

    Only the first two ``_``-separated fields of the part before the first
    ``.`` are used.
    """
    def as_pair(file_name):
        fields = file_name.split('.')[0].split('_')
        return int(fields[0]), int(fields[1])

    return [as_pair(file_name) for file_name in paths]
def setup(self, stage=None):
if stage == 'fit':
# for the training of the refine network, we need to have the latent data as the dataset
if 'refine' in self.hparams.model_name:
high_dim_var_log_dir = self.var_log_dir.replace('refine', 'encoder-decoder')
train_data = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_train', 'latent.npy')))
train_target = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_train', 'latent.npy')))
val_data = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_val', 'latent.npy')))
val_target = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_val', 'latent.npy')))
train_filepaths = list(np.load(os.path.join(high_dim_var_log_dir+'_train', 'ids.npy')))
val_filepaths = list(np.load(os.path.join(high_dim_var_log_dir+'_val', 'ids.npy')))
# convert the file strings into tuple so that we can use TensorDataset to load everything together
train_filepaths = torch.Tensor(self.paths_to_tuple(train_filepaths))
val_filepaths = torch.Tensor(self.paths_to_tuple(val_filepaths))
self.train_dataset = torch.utils.data.TensorDataset(train_data, train_target, train_filepaths)
self.val_dataset = torch.utils.data.TensorDataset(val_data, val_target, val_filepaths)
else:
self.train_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,
flag='train',
seed=self.hparams.seed,
object_name=self.hparams.dataset)
self.val_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,
flag='val',
seed=self.hparams.seed,
object_name=self.hparams.dataset)
if stage == 'test':
self.test_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,
flag='test',
seed=self.hparams.seed,
object_name=self.hparams.dataset)
# initialize lists for saving variables and latents during testing
self.all_filepaths = []
self.all_latents = []
self.all_refine_latents = []
self.all_reconstructed_latents = []
def train_dataloader(self):
    """Shuffled loader over ``self.train_dataset`` (batch size ``hparams.train_batch``)."""
    return torch.utils.data.DataLoader(
        dataset=self.train_dataset,
        batch_size=self.hparams.train_batch,
        shuffle=True,
        **self.kwargs,
    )
def val_dataloader(self):
    """Ordered (non-shuffled) loader over ``self.val_dataset`` (batch size ``hparams.val_batch``)."""
    return torch.utils.data.DataLoader(
        dataset=self.val_dataset,
        batch_size=self.hparams.val_batch,
        shuffle=False,
        **self.kwargs,
    )
def test_dataloader(self):
    """Ordered (non-shuffled) loader over ``self.test_dataset`` (batch size ``hparams.test_batch``)."""
    return torch.utils.data.DataLoader(
        dataset=self.test_dataset,
        batch_size=self.hparams.test_batch,
        shuffle=False,
        **self.kwargs,
    )
================================================
FILE: models_latentpred.py
================================================
import os
import glob
import torch
import shutil
import numpy as np
from torch import nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
from collections import OrderedDict
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from dataset import NeuralPhysDataset, NeuralPhysLatentDynamicsDataset
from model_utils import (EncoderDecoder,
EncoderDecoder64x1x1,
RefineDoublePendulumModel,
RefineSinglePendulumModel,
RefineCircularMotionModel,
RefineModelReLU,
RefineSwingStickNonMagneticModel,
RefineAirDancerModel,
RefineLavaLampModel,
RefineFireModel,
RefineElasticPendulumModel,
RefineReactionDiffusionModel,
LatentPredModel)
def mkdir(folder):
    """(Re)create *folder* as an empty directory, deleting anything already in it."""
    already_there = os.path.exists(folder)
    if already_there:
        shutil.rmtree(folder)
    os.makedirs(folder)
def rename_ckpt_for_multi_models(ckpt):
    """Extract the ``model.*`` sub-module weights from a Lightning checkpoint.

    Only keys whose first dotted component is ``model`` are kept, with that
    leading ``model.`` prefix removed. The result can then be loaded directly
    into the sub-module.

    Args:
        ckpt: checkpoint dict containing a ``'state_dict'`` mapping.

    Returns:
        OrderedDict of the stripped keys to their values, in original order.
    """
    prefix = 'model.'
    renamed_state_dict = OrderedDict()
    for key, value in ckpt['state_dict'].items():
        # strip only the leading prefix; str.replace would also mangle nested
        # names containing 'model.' (e.g. 'model.sub_model.w' -> 'sub_w')
        if key.startswith(prefix):
            renamed_state_dict[key[len(prefix):]] = value
    return renamed_state_dict
class VisLatentDynamicsModel(pl.LightningModule):
    """Learn dynamics in the low-dimensional state space of a refine network.

    ``high_dim_model`` (``EncoderDecoder64x1x1``) followed by ``refine_model``
    maps an image pair to a state vector (see ``data_to_state``). ``model``
    (``LatentPredModel``) is trained to map the state of one image pair to the
    state of the next pair. The two encoders are presumably loaded and frozen
    with ``load_high_dim_refine_model`` before training (the optimizer only
    receives parameters that still require gradients).
    """

    # dataset -> (in_channels of LatentPredModel, refine network class)
    # NOTE(review): in_channels presumably equals the refine network's state
    # size for that dataset - confirm against model_utils.
    _LATENT_PRED_CONFIGS = {
        'single_pendulum': (2, RefineSinglePendulumModel),
        'double_pendulum': (4, RefineDoublePendulumModel),
        'elastic_pendulum': (6, RefineElasticPendulumModel),
        'swingstick_non_magnetic': (4, RefineSwingStickNonMagneticModel),
        'air_dancer': (8, RefineAirDancerModel),
        'lava_lamp': (8, RefineLavaLampModel),
        'fire': (24, RefineFireModel),
        'reaction_diffusion': (2, RefineReactionDiffusionModel),
    }

    def __init__(self,
                 lr: float=1e-4,
                 seed: int=1,
                 if_cuda: bool=True,
                 if_test: bool=False,
                 gamma: float=0.5,
                 log_dir: str='logs',
                 train_batch: int=512,
                 val_batch: int=256,
                 test_batch: int=256,
                 num_workers: int=8,
                 model_name: str='encoder-decoder-64',
                 data_filepath: str='data',
                 dataset: str='single_pendulum',
                 lr_schedule: list=[20, 50, 100]) -> None:
        super().__init__()
        self.save_hyperparameters()
        # becomes True once load_high_dim_refine_model() has frozen the encoders; see train()
        self._frozen_backbones = False
        self.kwargs = {'num_workers': self.hparams.num_workers, 'pin_memory': True} if self.hparams.if_cuda else {}
        # folders for predicted images and saved variables; they are
        # (re)created - wiping old contents - only when NOT testing
        self.pred_log_dir = os.path.join(self.hparams.log_dir, 'predictions')
        self.var_log_dir = os.path.join(self.hparams.log_dir, 'variables')
        if not self.hparams.if_test:
            mkdir(self.pred_log_dir)
            mkdir(self.var_log_dir)
        self.__build_model()

    def __build_model(self):
        """Create the latent predictor, the high-dim autoencoder and the refine network.

        Raises:
            ValueError: if ``model_name``/``dataset`` is not a supported
                combination (previously nothing was built and training
                failed later on a missing attribute).
        """
        config = self._LATENT_PRED_CONFIGS.get(self.hparams.dataset)
        if self.hparams.model_name != 'latent-prediction' or config is None:
            raise ValueError(
                'unsupported model_name/dataset combination: '
                f'{self.hparams.model_name!r}/{self.hparams.dataset!r}')
        state_dim, refine_cls = config
        # creation order (model, high_dim_model, refine_model) kept for seeded initialisation
        self.model = LatentPredModel(in_channels=state_dim)
        self.high_dim_model = EncoderDecoder64x1x1(in_channels=3)
        self.refine_model = refine_cls(in_channels=64)
        # loss
        self.loss_func = nn.MSELoss()

    @staticmethod
    def _load_module_state(checkpoint_dir):
        """Return the renamed ``model.*`` state dict of the first ``*.ckpt`` in *checkpoint_dir*.

        Raises:
            FileNotFoundError: if the folder contains no ``.ckpt`` file.
        """
        ckpt_files = glob.glob(os.path.join(checkpoint_dir, '*.ckpt'))
        if not ckpt_files:
            raise FileNotFoundError(f'no .ckpt file found in {checkpoint_dir}')
        # load onto the CPU so GPU-saved checkpoints also work on CPU-only
        # machines; load_state_dict copies the tensors to the module's device
        ckpt = torch.load(ckpt_files[0], map_location='cpu')
        return rename_ckpt_for_multi_models(ckpt)

    @staticmethod
    def _freeze(module):
        """Disable gradients for all parameters of *module* and put it in eval mode."""
        for p in module.parameters():
            p.requires_grad = False
        module.eval()

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen encoders in eval mode.

        Lightning calls ``train()`` on the whole module when it (re)enters the
        training loop. That would otherwise switch ``high_dim_model`` and
        ``refine_model`` back to training mode, which changes the behaviour of
        mode-dependent layers (e.g. dropout / batch-norm), if they have any.
        """
        super().train(mode)
        if getattr(self, '_frozen_backbones', False):
            self.high_dim_model.eval()
            self.refine_model.eval()
        return self

    def load_model(self, checkpoint_filepath):
        """Load the latent predictor from the first checkpoint in the folder and freeze it (for testing)."""
        self.model.load_state_dict(self._load_module_state(checkpoint_filepath))
        self._freeze(self.model)

    def load_high_dim_refine_model(self, high_dim_checkpoint_filepath, refine_checkpoint_filepath):
        """Load the high-dim autoencoder and the refine network, then freeze both."""
        self.high_dim_model.load_state_dict(self._load_module_state(high_dim_checkpoint_filepath))
        self.refine_model.load_state_dict(self._load_module_state(refine_checkpoint_filepath))
        self._freeze(self.high_dim_model)
        self._freeze(self.refine_model)
        self._frozen_backbones = True

    def extract_decoder_from_refine_model(self):
        """Build ``self.refine_model_decoder`` from the refine network's child modules after index 4.

        NOTE(review): assumes the first four children of the refine network
        form its encoder - confirm against model_utils.
        """
        _layers = list(self.refine_model.children())[4:]
        self.refine_model_decoder = torch.nn.Sequential(*_layers)
        self._freeze(self.refine_model_decoder)

    def data_to_state(self, data):
        """Map an image batch to refined states.

        Shapes: (B, 3, 128, 256) -> (B, 64, 1, 1) -> (B, 64) -> (B, ID).
        """
        _, latent = self.high_dim_model(data, data, False)
        latent = latent.squeeze(-1).squeeze(-1)
        _, state = self.refine_model(latent)
        return state

    def _state_prediction_loss(self, batch):
        """MSE between the predicted next state and the state of the target frames."""
        data, target, _ = batch
        data_state = self.data_to_state(data)
        target_state = self.data_to_state(target)
        output_state = self.model(data_state)
        return self.loss_func(output_state, target_state)

    def training_step(self, batch, batch_idx):
        train_loss = self._state_prediction_loss(batch)
        self.log('train_loss', train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        val_loss = self._state_prediction_loss(batch)
        self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Predict the next state, decode it to images, log losses and save comparisons.

        Requires ``extract_decoder_from_refine_model`` to have been called.
        The test batch is ``(data, target, target_target, filepath)``. The
        state loss compares against the state of ``target``; the pixel loss
        compares the decoded prediction with ``target_target``.
        """
        data, target, target_target, filepath = batch
        data_state = self.data_to_state(data)
        target_state = self.data_to_state(target)
        output_state = self.model(data_state)
        latent_reconstructed = self.refine_model_decoder(output_state)
        output, _ = self.high_dim_model(data, latent_reconstructed.unsqueeze(2).unsqueeze(3), True)
        # calculate losses
        test_loss = self.loss_func(output_state, target_state)
        pixel_reconstruction_loss = self.loss_func(output, target_target)
        self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('pixel_reconstruction_loss', pixel_reconstruction_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # save the output images: the last axis holds two frames split at column 128
        halves = (slice(None, 128), slice(128, None))
        for idx in range(data.shape[0]):
            panels = [frames[idx, :, :, cols].unsqueeze(0)
                      for frames in (data, target, target_target, output)
                      for cols in halves]
            save_image(torch.cat(panels).cpu(), os.path.join(self.pred_log_dir, filepath[idx]), nrow=1)

    def configure_optimizers(self):
        # only parameters that still require gradients (i.e. not the frozen encoders)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.hparams.lr_schedule, gamma=self.hparams.gamma)
        return [optimizer], [scheduler]

    def setup(self, stage=None):
        if stage == 'fit':
            self.train_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,
                                                   flag='train',
                                                   seed=self.hparams.seed,
                                                   object_name=self.hparams.dataset)
            self.val_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath,
                                                 flag='val',
                                                 seed=self.hparams.seed,
                                                 object_name=self.hparams.dataset)
        if stage == 'test':
            self.test_dataset = NeuralPhysLatentDynamicsDataset(data_filepath=self.hparams.data_filepath,
                                                                flag='test',
                                                                seed=self.hparams.seed,
                                                                object_name=self.hparams.dataset)

    def train_dataloader(self):
        train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset,
                                                   batch_size=self.hparams.train_batch,
                                                   shuffle=True,
                                                   **self.kwargs)
        return train_loader

    def val_dataloader(self):
        val_loader = torch.utils.data.DataLoader(dataset=self.val_dataset,
                                                 batch_size=self.hparams.val_batch,
                                                 shuffle=False,
                                                 **self.kwargs)
        return val_loader

    def test_dataloader(self):
        test_loader = torch.utils.data.DataLoader(dataset=self.test_dataset,
                                                  batch_size=self.hparams.test_batch,
                                                  shuffle=False,
                                                  **self.kwargs)
        return test_loader
================================================
FILE: pred.py
================================================
"""
A. Long-term future prediction (model rollout)
1. encoder-decoder (0, 1 -> 8192-dim latent -> 2', 3'):
- feed (2', 3') images as input to predict (4', 5') images ...
2. encoder-decoder-64 (0, 1 -> 64-dim latent -> 2', 3'):
- feed (2', 3') images as input to predict (4', 5') images ...
3. encoder-decoder-64 & refine-64 (0, 1 -> id-dim latent -> 2', 3')
- feed (2', 3') images as input to predict (4', 5') images ...
4. encoder-decoder-64 & refine-64 hybrid:
- use refine-64 model at certain prediction steps
B. Long-term future prediction with perturbation (model rollout)
"""
import os
import sys
import glob
import yaml
import json
import torch
import pprint
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from munch import munchify
from torchvision import transforms
from collections import OrderedDict
from models import VisDynamicsModel
from models_latentpred import VisLatentDynamicsModel
from dataset import NeuralPhysDataset
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
def mkdir(folder):
    """Make *folder* an empty directory, removing any previous contents first."""
    if not os.path.exists(folder):
        os.makedirs(folder)
        return
    shutil.rmtree(folder)
    os.makedirs(folder)
def load_config(filepath):
    """Parse the YAML configuration file at *filepath* and return its contents.

    Raises:
        yaml.YAMLError: if the file is not valid YAML. The error is printed
            first, then re-raised. Returning ``None`` would only fail later
            with an unrelated AttributeError when the config is used.
    """
    with open(filepath, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
def seed(cfg):
    """Seed torch's CPU generator, and the CUDA one when ``cfg.if_cuda`` is set, with ``cfg.seed``."""
    value = cfg.seed
    torch.manual_seed(value)
    if not cfg.if_cuda:
        return
    torch.cuda.manual_seed(value)
    # for strict reproducibility, consider also enabling torch.set_deterministic(True)
def model_rollout():
    """Run an autoregressive long-term rollout on every test video.

    Command-line arguments read here:
        sys.argv[2]: config YAML.
        sys.argv[3]: folder holding the (encoder-decoder) checkpoint.
        sys.argv[4]: folder holding the refine checkpoint (refine models only).
        sys.argv[6]: ``pred_len``.

    For every even t, frames (t, t+1) are used to predict frames (t+2, t+3).
    The prediction is fed back as the next input. Whenever t is a multiple of
    ``pred_len``, the input is reset to the ground-truth frames. Predicted
    frames are written to ``<log_dir>/prediction_long_term/model_rollout/<id>/``.
    The per-step MSE of every video goes to ``test_loss.json`` in that folder.

    Raises:
        ValueError: if ``cfg.model_name`` is neither an encoder-decoder nor a refine model.
        FileNotFoundError: if a checkpoint folder contains no ``*.ckpt`` file.
    """
    config_filepath = str(sys.argv[2])
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    seed(cfg)
    seed_everything(cfg.seed)
    is_autoencoder = cfg.model_name in ('encoder-decoder', 'encoder-decoder-64')
    if not is_autoencoder and 'refine' not in cfg.model_name:
        raise ValueError(f'model rollout does not support model_name {cfg.model_name!r}')
    device = 'cuda' if cfg.if_cuda else 'cpu'
    log_dir = '_'.join([cfg.log_dir,
                        cfg.dataset,
                        cfg.model_name,
                        str(cfg.seed)])
    model = VisDynamicsModel(lr=cfg.lr,
                             seed=cfg.seed,
                             if_cuda=cfg.if_cuda,
                             if_test=True,
                             gamma=cfg.gamma,
                             log_dir=log_dir,
                             train_batch=cfg.train_batch,
                             val_batch=cfg.val_batch,
                             test_batch=cfg.test_batch,
                             num_workers=cfg.num_workers,
                             model_name=cfg.model_name,
                             data_filepath=cfg.data_filepath,
                             dataset=cfg.dataset,
                             lr_schedule=cfg.lr_schedule)

    def load_first_ckpt(folder):
        # first '*.ckpt' in folder, loaded onto the CPU; load_state_dict then
        # copies the weights to the module's own device
        ckpt_files = glob.glob(os.path.join(folder, '*.ckpt'))
        if not ckpt_files:
            raise FileNotFoundError(f'no .ckpt file found in {folder}')
        return torch.load(ckpt_files[0], map_location='cpu')

    # load model
    if is_autoencoder:
        model.load_state_dict(load_first_ckpt(str(sys.argv[3]))['state_dict'])
    else:
        model.model.load_state_dict(rename_ckpt_for_multi_models(load_first_ckpt(str(sys.argv[4]))))
        model.high_dim_model.load_state_dict(rename_ckpt_for_multi_models(load_first_ckpt(str(sys.argv[3]))))
    model = model.to(device)
    model.eval()
    model.freeze()

    def predict(data):
        # predict the next frame pair from the input frame pair
        data = data.to(device)
        if cfg.model_name == 'encoder-decoder':
            output, _ = model.model(data)
        elif cfg.model_name == 'encoder-decoder-64':
            output, _ = model.model(data, data, False)
        else:
            # refine: 64-dim latent -> refine network -> decode the reconstructed latent
            _, latent = model.high_dim_model(data, data, False)
            latent = latent.squeeze(-1).squeeze(-1)
            latent_reconstructed, _ = model.model(latent)
            output, _ = model.high_dim_model(data, latent_reconstructed.unsqueeze(2).unsqueeze(3), True)
        return output

    def load_pair(vid_filepath, first_idx, suf):
        # frames first_idx and first_idx + 1 concatenated along the width, plus a batch dim
        frames = [get_data(os.path.join(vid_filepath, f'{first_idx + k}.{suf}')) for k in (0, 1)]
        return torch.cat(frames, 2).unsqueeze(0)

    def save_pair(pair, folder, first_idx, suf):
        # save the left / right 128-px halves of pair[0] as frames first_idx and first_idx + 1
        tensor_to_img(pair[0, :, :, :128]).save(os.path.join(folder, f'{first_idx}.{suf}'))
        tensor_to_img(pair[0, :, :, 128:]).save(os.path.join(folder, f'{first_idx + 1}.{suf}'))

    # get all the test video ids
    data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset)
    with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:
        test_vid_ids = json.load(file)['test']
    pred_len = int(sys.argv[6])
    long_term_folder = os.path.join(log_dir, 'prediction_long_term', 'model_rollout')
    # create it up front so test_loss.json can be written even without test videos
    os.makedirs(long_term_folder, exist_ok=True)
    loss_dict = {}
    for p_vid_idx in tqdm(test_vid_ids):
        vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))
        frame_files = os.listdir(vid_filepath)
        suf = frame_files[0].split('.')[-1]
        saved_folder = os.path.join(long_term_folder, str(p_vid_idx))
        mkdir(saved_folder)
        data = None
        loss_lst = []
        # even start indices only: each step predicts the next two frames
        for start_frame_idx in range(0, len(frame_files) - 3, 2):
            # take the initial input from ground truth data
            if start_frame_idx % pred_len == 0:
                data = load_pair(vid_filepath, start_frame_idx, suf)
            target = load_pair(vid_filepath, start_frame_idx + 2, suf)
            output = predict(data)
            loss_lst.append(model.loss_func(output, target.to(device)).item())
            # save (2', 3'), (4', 5'), ...
            save_pair(output, saved_folder, start_frame_idx + 2, suf)
            # the output becomes the input data in the next iteration
            data = output.detach().float().cpu()
        loss_dict[p_vid_idx] = loss_lst
    # save the test loss for all the testing videos
    with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file:
        json.dump(loss_dict, file, indent=4)
def model_rollout_hybrid(step):
    """Long-term rollout that only applies the refine network on some steps.

    Behaves like ``model_rollout`` for refine models (same ``sys.argv``
    layout), with one difference. The refine network is used only when
    ``(start_frame_idx + 2) % (2 * step + 2) == 0``, i.e. on every
    ``(step + 1)``-th prediction step counting from 1. All other steps use
    the ``encoder-decoder-64`` autoencoder alone. Results are written to
    ``<log_dir>/prediction_long_term/hybrid_rollout_<step>/``.

    Raises:
        ValueError: if ``cfg.model_name`` is not a refine model.
        FileNotFoundError: if a checkpoint folder contains no ``*.ckpt`` file.
    """
    config_filepath = str(sys.argv[2])
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    seed(cfg)
    seed_everything(cfg.seed)
    if 'refine' not in cfg.model_name:
        # a real exception: `assert` statements are stripped under `python -O`
        raise ValueError("the hybrid scheme is only supported with refine model...")
    device = 'cuda' if cfg.if_cuda else 'cpu'
    log_dir = '_'.join([cfg.log_dir,
                        cfg.dataset,
                        cfg.model_name,
                        str(cfg.seed)])
    model = VisDynamicsModel(lr=cfg.lr,
                             seed=cfg.seed,
                             if_cuda=cfg.if_cuda,
                             if_test=True,
                             gamma=cfg.gamma,
                             log_dir=log_dir,
                             train_batch=cfg.train_batch,
                             val_batch=cfg.val_batch,
                             test_batch=cfg.test_batch,
                             num_workers=cfg.num_workers,
                             model_name=cfg.model_name,
                             data_filepath=cfg.data_filepath,
                             dataset=cfg.dataset,
                             lr_schedule=cfg.lr_schedule)

    def load_first_ckpt(folder):
        # first '*.ckpt' in folder, loaded onto the CPU; load_state_dict then
        # copies the weights to the module's own device
        ckpt_files = glob.glob(os.path.join(folder, '*.ckpt'))
        if not ckpt_files:
            raise FileNotFoundError(f'no .ckpt file found in {folder}')
        return torch.load(ckpt_files[0], map_location='cpu')

    # load model
    model.model.load_state_dict(rename_ckpt_for_multi_models(load_first_ckpt(str(sys.argv[4]))))
    model.high_dim_model.load_state_dict(rename_ckpt_for_multi_models(load_first_ckpt(str(sys.argv[3]))))
    model = model.to(device)
    model.eval()
    model.freeze()

    def load_pair(vid_filepath, first_idx, suf):
        # frames first_idx and first_idx + 1 concatenated along the width, plus a batch dim
        frames = [get_data(os.path.join(vid_filepath, f'{first_idx + k}.{suf}')) for k in (0, 1)]
        return torch.cat(frames, 2).unsqueeze(0)

    def save_pair(pair, folder, first_idx, suf):
        # save the left / right 128-px halves of pair[0] as frames first_idx and first_idx + 1
        tensor_to_img(pair[0, :, :, :128]).save(os.path.join(folder, f'{first_idx}.{suf}'))
        tensor_to_img(pair[0, :, :, 128:]).save(os.path.join(folder, f'{first_idx + 1}.{suf}'))

    # get all the test video ids
    data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset)
    with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:
        test_vid_ids = json.load(file)['test']
    pred_len = int(sys.argv[6])
    long_term_folder = os.path.join(log_dir, 'prediction_long_term', f'hybrid_rollout_{step}')
    # create it up front so test_loss.json can be written even without test videos
    os.makedirs(long_term_folder, exist_ok=True)
    loss_dict = {}
    for p_vid_idx in tqdm(test_vid_ids):
        vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))
        frame_files = os.listdir(vid_filepath)
        suf = frame_files[0].split('.')[-1]
        saved_folder = os.path.join(long_term_folder, str(p_vid_idx))
        mkdir(saved_folder)
        data = None
        loss_lst = []
        for start_frame_idx in range(0, len(frame_files) - 3, 2):
            # take the initial input from ground truth data
            if start_frame_idx % pred_len == 0:
                data = load_pair(vid_filepath, start_frame_idx, suf)
            target = load_pair(vid_filepath, start_frame_idx + 2, suf)
            inputs = data.to(device)
            if (start_frame_idx + 2) % (2 * step + 2) == 0:
                # refine step: encode, refine the 64-dim latent, decode it
                _, latent = model.high_dim_model(inputs, inputs, False)
                latent = latent.squeeze(-1).squeeze(-1)
                latent_reconstructed, _ = model.model(latent)
                output, _ = model.high_dim_model(inputs, latent_reconstructed.unsqueeze(2).unsqueeze(3), True)
            else:
                output, _ = model.high_dim_model(inputs, inputs, False)
            loss_lst.append(model.loss_func(output, target.to(device)).item())
            # save (2', 3'), (4', 5'), ...
            save_pair(output, saved_folder, start_frame_idx + 2, suf)
            # the output becomes the input data in the next iteration
            data = output.detach().float().cpu()
        loss_dict[p_vid_idx] = loss_lst
    # save the test loss for all the testing videos
    with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file:
        json.dump(loss_dict, file, indent=4)
def model_rollout_perturb(perturb_type, perturb_level):
    """Long-term rollout whose ground-truth inputs are perturbed.

    Same ``sys.argv`` layout and rollout scheme as ``model_rollout``, with two
    differences:
    - Every time the input is reset to ground truth, the frames are loaded
      with ``get_data_perturb(path, perturb_type, perturb_level)``.
    - Those perturbed inputs are saved alongside the predictions.
    Targets stay unperturbed. Results are written to
    ``<log_dir>/prediction_long_term/model_rollout_perturb_<type>_<level>/``.

    Raises:
        ValueError: if ``cfg.model_name`` is neither an encoder-decoder nor a refine model.
        FileNotFoundError: if a checkpoint folder contains no ``*.ckpt`` file.
    """
    config_filepath = str(sys.argv[2])
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    seed(cfg)
    seed_everything(cfg.seed)
    is_autoencoder = cfg.model_name in ('encoder-decoder', 'encoder-decoder-64')
    if not is_autoencoder and 'refine' not in cfg.model_name:
        raise ValueError(f'perturbed rollout does not support model_name {cfg.model_name!r}')
    device = 'cuda' if cfg.if_cuda else 'cpu'
    log_dir = '_'.join([cfg.log_dir,
                        cfg.dataset,
                        cfg.model_name,
                        str(cfg.seed)])
    model = VisDynamicsModel(lr=cfg.lr,
                             seed=cfg.seed,
                             if_cuda=cfg.if_cuda,
                             if_test=True,
                             gamma=cfg.gamma,
                             log_dir=log_dir,
                             train_batch=cfg.train_batch,
                             val_batch=cfg.val_batch,
                             test_batch=cfg.test_batch,
                             num_workers=cfg.num_workers,
                             model_name=cfg.model_name,
                             data_filepath=cfg.data_filepath,
                             dataset=cfg.dataset,
                             lr_schedule=cfg.lr_schedule)

    def load_first_ckpt(folder):
        # first '*.ckpt' in folder, loaded onto the CPU; load_state_dict then
        # copies the weights to the module's own device
        ckpt_files = glob.glob(os.path.join(folder, '*.ckpt'))
        if not ckpt_files:
            raise FileNotFoundError(f'no .ckpt file found in {folder}')
        return torch.load(ckpt_files[0], map_location='cpu')

    # load model
    if is_autoencoder:
        model.load_state_dict(load_first_ckpt(str(sys.argv[3]))['state_dict'])
    else:
        model.model.load_state_dict(rename_ckpt_for_multi_models(load_first_ckpt(str(sys.argv[4]))))
        model.high_dim_model.load_state_dict(rename_ckpt_for_multi_models(load_first_ckpt(str(sys.argv[3]))))
    model = model.to(device)
    model.eval()
    model.freeze()

    def predict(data):
        # predict the next frame pair from the input frame pair
        data = data.to(device)
        if cfg.model_name == 'encoder-decoder':
            output, _ = model.model(data)
        elif cfg.model_name == 'encoder-decoder-64':
            output, _ = model.model(data, data, False)
        else:
            # refine: 64-dim latent -> refine network -> decode the reconstructed latent
            _, latent = model.high_dim_model(data, data, False)
            latent = latent.squeeze(-1).squeeze(-1)
            latent_reconstructed, _ = model.model(latent)
            output, _ = model.high_dim_model(data, latent_reconstructed.unsqueeze(2).unsqueeze(3), True)
        return output

    def load_perturbed(path):
        return get_data_perturb(path, perturb_type, perturb_level)

    def load_pair(vid_filepath, first_idx, suf, loader):
        # frames first_idx and first_idx + 1 concatenated along the width, plus a batch dim
        frames = [loader(os.path.join(vid_filepath, f'{first_idx + k}.{suf}')) for k in (0, 1)]
        return torch.cat(frames, 2).unsqueeze(0)

    def save_pair(pair, folder, first_idx, suf):
        # save the left / right 128-px halves of pair[0] as frames first_idx and first_idx + 1
        tensor_to_img(pair[0, :, :, :128]).save(os.path.join(folder, f'{first_idx}.{suf}'))
        tensor_to_img(pair[0, :, :, 128:]).save(os.path.join(folder, f'{first_idx + 1}.{suf}'))

    # get all the test video ids
    data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset)
    with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file:
        test_vid_ids = json.load(file)['test']
    pred_len = int(sys.argv[6])
    long_term_folder = os.path.join(log_dir, 'prediction_long_term', f'model_rollout_perturb_{perturb_type}_{perturb_level}')
    # create it up front so test_loss.json can be written even without test videos
    os.makedirs(long_term_folder, exist_ok=True)
    loss_dict = {}
    for p_vid_idx in tqdm(test_vid_ids):
        vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))
        frame_files = os.listdir(vid_filepath)
        suf = frame_files[0].split('.')[-1]
        saved_folder = os.path.join(long_term_folder, str(p_vid_idx))
        mkdir(saved_folder)
        data = None
        loss_lst = []
        for start_frame_idx in range(0, len(frame_files) - 3, 2):
            # take the (perturbed) initial input from ground truth data and save it
            if start_frame_idx % pred_len == 0:
                data = load_pair(vid_filepath, start_frame_idx, suf, load_perturbed)
                save_pair(data, saved_folder, start_frame_idx, suf)
            target = load_pair(vid_filepath, start_frame_idx + 2, suf, get_data)
            output = predict(data)
            loss_lst.append(model.loss_func(output, target.to(device)).item())
            # save (2', 3'), (4', 5'), ...
            save_pair(output, saved_folder, start_frame_idx + 2, suf)
            # the output becomes the input data in the next iteration
            data = output.detach().float().cpu()
        loss_dict[p_vid_idx] = loss_lst
    # save the test loss for all the testing videos
    with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file:
        json.dump(loss_dict, file, indent=4)
def rename_ckpt_for_multi_models(ckpt):
    """Return ``ckpt['state_dict']`` with the sub-model name prefixes removed.

    Keys that contain ``'high_dim_model'`` have every ``'high_dim_model.'``
    substring deleted; all remaining keys have every ``'model.'`` substring
    deleted. Entries keep their original insertion order; if two keys collapse
    to the same name, the later value wins (first position is kept).

    Note that ``str.replace`` deletes *all* occurrences, not only a leading
    prefix, e.g. ``'model.sub_model.w'`` becomes ``'sub_w'``.
    NOTE(review): confirm this whole-string replacement is intended.

    Args:
        ckpt: checkpoint mapping holding a ``'state_dict'`` mapping.

    Returns:
        ``OrderedDict`` from renamed key to the original value.
    """
    def _strip(key):
        # Pick which prefix to remove based on which sub-model the key belongs to.
        prefix = 'high_dim_model.' if 'high_dim_model' in key else 'model.'
        return key.replace(prefix, '')

    return OrderedDict(
        (_strip(key), value) for key, value in ckpt['state_dict'].items()
    )
def get_data(filepath):
    """Load an image file as a channel-first float tensor.

    The image is resized to 128x128, divided by 255.0 and its trailing channel
    axis is moved to the front, so an 8-bit RGB image yields a float32 tensor of
    shape ``(3, 128, 128)`` with values in [0, 1].

    Args:
        filepath: path of the image file to read.

    Returns:
        ``torch`` float32 tensor of shape ``(channels, 128, 128)``.
    """
    pixels = np.array(Image.open(filepath).resize((128, 128)))
    scaled = torch.tensor(pixels / 255.0)
    return scaled.permute(2, 0, 1).float()
def get_data_perturb(filepath, perturb_type, perturb_level):
    """Load an image like ``get_data`` after applying a synthetic perturbation.

    The random generator is seeded with the integer name of the directory that
    contains ``filepath`` (the video index), so every frame of one video gets
    the same replacement colour and the same noise pattern.

    Assumes a 3-channel (RGB) image, as the colour comparisons/assignments
    below operate on 3-vectors.

    Args:
        filepath: image path of the form ``.../<video_idx>/<frame>.<ext>``; the
            parent directory name must parse as an int.
        perturb_type: one of
            ``'background_replace'`` - recolour pixels equal to the background
                colour (215, 205, 192) inside the top-left square of side
                ``2 ** (perturb_level - 1)``;
            ``'background_cover'`` - paint that whole top-left square;
            ``'white_noise'`` - add Gaussian noise with standard deviation
                ``255 * (2 ** (perturb_level - 1) / 128) ** 2`` to every pixel
                (not clipped, so values may leave [0, 255]);
            any other value leaves the image unchanged.
        perturb_level: integer >= 1 controlling the perturbation size/strength.
            A square larger than the 128x128 image is clipped to the image.

    Returns:
        float32 tensor of shape ``(channels, 128, 128)``, scaled by 1/255.
    """
    with Image.open(filepath) as img:
        data = np.array(img.resize((128, 128)))
    # Use os.path instead of splitting on '/' so OS-native separators (and
    # doubled slashes from os.path.join) still yield the video directory name.
    video_idx = int(os.path.basename(os.path.dirname(filepath)))
    rng = np.random.RandomState(video_idx)
    # Draw the colour first: the white-noise branch below consumes the RNG
    # stream after this call, so the order must stay fixed for reproducibility.
    new_bg_color = rng.randint(256, size=3)
    bg_color = np.array([215, 205, 192])
    size = 2 ** (perturb_level - 1)
    if perturb_type == 'background_replace':
        region = data[:size, :size]  # basic slice -> view, writes reach `data`
        is_bg = np.all(region == bg_color, axis=-1)
        region[is_bg] = new_bg_color
    elif perturb_type == 'background_cover':
        data[:size, :size] = new_bg_color
    elif perturb_type == 'white_noise':
        sigma = 255.0 * (size / 128) ** 2
        data = data + rng.normal(0, sigma, data.shape)
    # Any other perturb_type deliberately leaves the image untouched.
    data = torch.tensor(data / 255.0)
    return data.permute(2, 0, 1).float()
def tensor_to_img(out_tensor):
    """Convert a channel-first image tensor into an RGB PIL image.

    ``out_tensor`` is laid out as ``(C, H, W)`` - the rollout code passes
    ``(C, 128, 128)`` halves of a frame pair - and the result is an H x W image
    in ``"RGB"`` mode. Float inputs are assumed to be in [0, 1], matching the
    scaling used by ``get_data``.
    """
    pil_image = transforms.ToPILImage()(out_tensor)
    return pil_image.convert("RGB")
if __name__ == '__main__':
    # Dispatch on the prediction scheme passed as the first CLI argument, e.g.
    # 'model-rollout', 'latent-prediction', a scheme containing 'hybrid' whose
    # last '-'-separated field is the step, or
    # 'model-rollout-perturb-<type>-<level>' (e.g. '...-background_cover-5').
    scheme = str(sys.argv[1])
    if scheme == 'model-rollout':
        model_rollout()
    elif 'hybrid' in scheme:
        step = int(scheme.split('-')[-1])
        model_rollout_hybrid(step)
    elif scheme == 'latent-prediction':
        latent_prediction()
    elif 'perturb' in scheme:
        # Perturbation names use underscores so they survive the split on '-'.
        fields = scheme.split('-')
        perturb_type = fields[-2]
        perturb_level = int(fields[-1])
        model_rollout_perturb(perturb_type, perturb_level)
    else:
        # Raise explicitly: an `assert` is stripped under `python -O`, which
        # would turn an unsupported scheme into a silent no-op.
        raise ValueError("prediction scheme is not supported...")
================================================
FILE: requirements.txt
================================================
absl-py==0.7.1
aiohttp==3.7.4.post0
aiohttp-cors==0.7.0
aioredis==1.3.1
astor==0.7.1
astunparse==1.6.3
async-timeout==3.0.1
atari-py==0.1.15
atomicwrites==1.3.0
attrs==19.3.0
audioread==2.1.6
backcall==0.1.0
backports.shutil-get-terminal-size==1.0.0
bleach==1.5.0
blessings==1.7
cachetools==4.1.1
cattrs==1.0.0
certifi==2019.3.9
chardet==3.0.4
chumpy==0.70
click==8.0.1
cloudpickle==0.8.1
colorama==0.4.4
contextvars==2.4
cycler==0.10.0
Cython==0.29.22
dataclasses==0.8
decorator==4.4.0
deepdish==0.3.6
defusedxml==0.6.0
dill==0.2.9
Django==2.2.1
docopt==0.6.2
entrypoints==0.3
filelock==3.0.12
Flask==2.0.1
flatbuffers==1.12
fsspec==0.8.4
future==0.17.1
gast==0.3.3
glob2==0.6
google-api-core==1.28.0
google-auth==1.30.1
google-auth-oauthlib==0.4.1
google-pasta==0.2.0
googleapis-common-protos==1.53.0
gpustat==0.6.0
grpcio==1.34.0
gym==0.12.5
h5py==2.10.0
hiredis==2.0.0
html5lib==0.9999999
idna==2.8
idna-ssl==1.1.0
image==1.5.25
imageio==2.5.0
immutables==0.15
importlib-metadata==4.3.0
imutils==0.5.2
ipdb==0.12
ipykernel==5.1.0
ipython==7.4.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
itsdangerous==2.0.1
jedi==0.13.3
Jinja2==3.0.1
joblib==0.13.2
jsonschema==3.0.1
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
kaleido==0.2.1
Keras-Preprocessing==1.1.2
kiwisolver==1.0.1
librosa==0.6.3
llvmlite==0.28.0
Markdown==3.1
MarkupSafe==2.0.1
matplotlib==2.2.2
mistune==0.8.4
mock==3.0.5
more-itertools==7.0.0
mpi4py==3.0.3
msgpack==1.0.2
multidict==5.1.0
munch==2.5.0
munkres==1.0.12
nbconvert==5.4.1
nbformat==4.4.0
networkx==2.3
nibabel==2.4.0
notebook==5.7.8
numba==0.43.1
numexpr==2.6.9
numpy==1.19.4
nvidia-ml-py3==7.352.0
oauthlib==3.1.0
opencensus==0.7.13
opencensus-context==0.1.2
opencv-python==4.5.2.52
opt-einsum==3.3.0
packaging==20.9
pandas==0.24.2
pandocfilters==1.4.2
parso==0.3.4
pexpect==4.6.0
pickleshare==0.7.5
Pillow==7.1.2
pkg-resources==0.0.0
plotly==4.14.3
pluggy==0.11.0
prometheus-client==0.10.1
prompt-toolkit==2.0.9
protobuf==3.17.1
psutil==5.8.0
ptyprocess==0.6.0
py==1.8.0
py-spy==0.3.7
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyglet==1.3.2
Pygments==2.3.1
PyOpenGL==3.1.0
pyparsing==2.3.1
pyrsistent==0.14.11
pytest==3.10.1
python-dateutil==2.8.0
pytorch-lightning==1.0.8
pytz==2019.1
PyWavelets==1.0.3
PyYAML==5.1
pyzmq==18.0.1
qtconsole==4.4.3
ray==1.3.0
redis==3.5.3
reprint==0.5.2
requests==2.21.0
requests-oauthlib==1.3.0
resampy==0.2.1
retrying==1.3.3
rsa==4.6
scikit-image==0.15.0
scikit-learn==0.20.3
scipy==1.4.1
seaborn==0.9.0
Send2Trash==1.5.0
Shapely==1.6.4.post2
six==1.16.0
sklearn==0.0
sqlparse==0.3.0
stable-baselines==2.5.1
tables==3.5.1
tabulate==0.8.9
tensorboard==2.2.2
tensorboard-plugin-wit==1.7.0
tensorboardX==2.2
termcolor==1.1.0
terminado==0.8.2
testpath==0.4.2
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.7.0
torchfile==0.1.0
torchvision==0.8.1
tornado==6.0.2
tqdm==4.54.1
traitlets==4.3.2
typing-extensions==3.7.4.3
urllib3==1.24.1
visdom==0.1.8.8
wcwidth==0.1.7
webencodings==0.5.1
websocket-client==0.56.0
Werkzeug==2.0.1
widgetsnbextension==3.4.2
wrapt==1.12.1
yarl==1.6.3
zipp==3.4.1
zmq==0.0.0
================================================
FILE: scripts/encoder_decoder_64_eval.sh
================================================
#!/bin/bash
# Evaluate (eval-eval) the three encoder-decoder-64 configs of one dataset, one
# after another, inside a detached screen session.
# Usage (from scripts/, paths are relative): ./encoder_decoder_64_eval.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "============================================================================================="
echo "============== Evaluating encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "============================================================================================="
cmds=""
for run in 1 2 3; do
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../eval.py ../configs/${dataset}/model64/config${run}.yaml ./logs_${dataset}_encoder-decoder-64_${run}/lightning_logs/checkpoints NA eval-eval NA; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-64" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_64_eval_gather.sh
================================================
#!/bin/bash
# Run eval.py in eval-train (gathering) mode for the three encoder-decoder-64
# configs of one dataset, sequentially, in a detached screen session.
# Usage (from scripts/): ./encoder_decoder_64_eval_gather.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "========================================================================================================="
echo "============== Evaluating (Gathering) encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "========================================================================================================="
cmds=""
for run in 1 2 3; do
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../eval.py ../configs/${dataset}/model64/config${run}.yaml ./logs_${dataset}_encoder-decoder-64_${run}/lightning_logs/checkpoints NA eval-train NA; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-64" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_64_long_term_model_rollout.sh
================================================
#!/bin/bash
# Long-term model rollout (pred.py model-rollout, argument 60) for the three
# encoder-decoder-64 configs of one dataset, in a detached screen session.
# Usage (from scripts/): ./encoder_decoder_64_long_term_model_rollout.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "=========================================================================================================="
echo "============== Long-term model rollout encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "=========================================================================================================="
cmds=""
for run in 1 2 3; do
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../pred.py model-rollout ../configs/${dataset}/model64/config${run}.yaml ./logs_${dataset}_encoder-decoder-64_${run}/lightning_logs/checkpoints NA NA 60; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-longterm-modelrollout" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_64_long_term_model_rollout_perturb_all.sh
================================================
#!/bin/bash
# Perturbed long-term rollouts for the three encoder-decoder-64 configs of one
# dataset: every perturbation type x level 5/7/8 x config 1/2/3, run in that
# nesting order inside a single detached screen session.
# Usage (from scripts/): ./encoder_decoder_64_long_term_model_rollout_perturb_all.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "=========================================================================================================="
echo "============== Long-term model rollout encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "=========================================================================================================="
cmds=""
for perturb in background_cover background_replace white_noise; do
    for level in 5 7 8; do
        for run in 1 2 3; do
            cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../pred.py model-rollout-perturb-${perturb}-${level} ../configs/${dataset}/model64/config${run}.yaml ./logs_${dataset}_encoder-decoder-64_${run}/lightning_logs/checkpoints NA NA 60; "
        done
    done
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-longterm-modelrollout" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_64_train.sh
================================================
#!/bin/bash
# Train the three encoder-decoder-64 configs of one dataset, sequentially, in a
# detached screen session.
# Usage (from scripts/): ./encoder_decoder_64_train.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "==========================================================================================="
echo "============== Training encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "==========================================================================================="
cmds=""
for run in 1 2 3; do
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../main.py ../configs/${dataset}/model64/config${run}.yaml; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "train-${dataset}-64" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_estimate_dimension.sh
================================================
#!/bin/bash
# Estimate intrinsic dimension (model-latent mode) for the three encoder-decoder
# configs of one dataset, with OpenMP limited to 4 threads, in a detached
# screen session.
# Usage (from scripts/): ./encoder_decoder_estimate_dimension.sh <dataset>
dataset=$1
echo "============================================================================================"
echo "========== Estimating intrinsic dimension from encoder-decoder model on: $dataset =========="
echo "============================================================================================"
cmds=""
for run in 1 2 3; do
    cmds="${cmds}OMP_NUM_THREADS=4 python ../analysis/eval_intrinsic_dimension.py ../configs/${dataset}/model/config${run}.yaml model-latent NA; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-dimension" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_eval.sh
================================================
#!/bin/bash
# Evaluate (eval-eval) the three encoder-decoder configs of one dataset,
# sequentially, in a detached screen session.
# Usage (from scripts/): ./encoder_decoder_eval.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "=========================================================================================="
echo "============== Evaluating encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "=========================================================================================="
cmds=""
for run in 1 2 3; do
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../eval.py ../configs/${dataset}/model/config${run}.yaml ./logs_${dataset}_encoder-decoder_${run}/lightning_logs/checkpoints NA eval-eval NA; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_eval_gather.sh
================================================
#!/bin/bash
# Run eval.py in eval-train (gathering) mode for the three encoder-decoder
# configs of one dataset, sequentially, in a detached screen session.
# Usage (from scripts/): ./encoder_decoder_eval_gather.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "======================================================================================================"
echo "============== Evaluating (Gathering) encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "======================================================================================================"
cmds=""
for run in 1 2 3; do
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../eval.py ../configs/${dataset}/model/config${run}.yaml ./logs_${dataset}_encoder-decoder_${run}/lightning_logs/checkpoints NA eval-train NA; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_long_term_model_rollout.sh
================================================
#!/bin/bash
# Long-term model rollout (pred.py model-rollout, argument 60) for the three
# encoder-decoder configs of one dataset, in a detached screen session.
# Usage (from scripts/): ./encoder_decoder_long_term_model_rollout.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "======================================================================================================="
echo "============== Long-term model rollout encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "======================================================================================================="
cmds=""
for run in 1 2 3; do
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../pred.py model-rollout ../configs/${dataset}/model/config${run}.yaml ./logs_${dataset}_encoder-decoder_${run}/lightning_logs/checkpoints NA NA 60; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-longterm-modelrollout" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_long_term_model_rollout_perturb_all.sh
================================================
#!/bin/bash
# Perturbed long-term rollouts for the three encoder-decoder configs of one
# dataset, run sequentially in one detached screen session.
# The perturbation order below (starting at background_replace-8) is the order
# this script has always used; it is listed explicitly to keep it unchanged.
# Usage (from scripts/): ./encoder_decoder_long_term_model_rollout_perturb_all.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "======================================================================================================="
echo "============== Long-term model rollout encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "======================================================================================================="
perturbations="background_replace-8 white_noise-5 white_noise-7 white_noise-8 background_cover-5 background_cover-7 background_cover-8 background_replace-5 background_replace-7"
cmds=""
for perturb in $perturbations; do
    for run in 1 2 3; do
        cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../pred.py model-rollout-perturb-${perturb} ../configs/${dataset}/model/config${run}.yaml ./logs_${dataset}_encoder-decoder_${run}/lightning_logs/checkpoints NA NA 60; "
    done
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-longterm-modelrollout" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/encoder_decoder_train.sh
================================================
#!/bin/bash
# Train the three encoder-decoder configs of one dataset, sequentially, in a
# detached screen session.
# Usage (from scripts/): ./encoder_decoder_train.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "========================================================================================"
echo "============== Training encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "========================================================================================"
cmds=""
for run in 1 2 3; do
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../main.py ../configs/${dataset}/model/config${run}.yaml; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "train-${dataset}" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/latentpred_train.sh
================================================
#!/bin/bash
# Train the three latentpred configs of one dataset, each pointed at the
# matching encoder-decoder-64 and refine-64 checkpoint directories, in a
# detached screen session.
# Usage (from scripts/): ./latentpred_train.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "========================================================================================"
echo "================ Training latentpred model on: $dataset (gpu id: $gpu) ================="
echo "========================================================================================"
cmds=""
for run in 1 2 3; do
    ckpts="./logs_${dataset}_encoder-decoder-64_${run}/lightning_logs/checkpoints/ ./logs_${dataset}_refine-64_${run}/lightning_logs/checkpoints/"
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../main.py latentpred ../configs/${dataset}/latentpred/config${run}.yaml ${ckpts}; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "train-${dataset}" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/long_term_eval_stability.sh
================================================
#!/bin/bash
# Evaluate long-term prediction stability (stability.py, argument 60). For each
# config 1/2/3 it scores the model_rollout output of the encoder-decoder,
# encoder-decoder-64 and refine-64 models, in that order, inside one detached
# screen session.
# Usage (from scripts/): ./long_term_eval_stability.sh <dataset> <gpu_id>
dataset=$1
gpu=$2
echo "============================================================================================"
echo "========= Evaluating stability of long term prediction on: $dataset (gpu id: $gpu) ========="
echo "============================================================================================"
cmds=""
for run in 1 2 3; do
    ckpts="./logs_${dataset}_latent-prediction_${run}/lightning_logs/checkpoints/ ./logs_${dataset}_encoder-decoder-64_${run}/lightning_logs/checkpoints/ ./logs_${dataset}_refine-64_${run}/lightning_logs/checkpoints/"
    for rollout_model in encoder-decoder encoder-decoder-64 refine-64; do
        cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../stability.py ../configs/${dataset}/latentpred/config${run}.yaml ${ckpts} ./logs_${dataset}_${rollout_model}_${run}/prediction_long_term/model_rollout/ 60; "
    done
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-long-term-stab" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/long_term_eval_stability_hybrid.sh
================================================
#!/bin/bash
# Evaluate long-term prediction stability (stability.py, argument 60) on the
# refine-64 hybrid_rollout_<step> outputs of configs 1/2/3, in a detached
# screen session.
# Usage (from scripts/): ./long_term_eval_stability_hybrid.sh <dataset> <gpu_id> <step>
dataset=$1
gpu=$2
step=$3
echo "============================================================================================"
echo "========= Evaluating stability of long term prediction on: $dataset (gpu id: $gpu) ========="
echo "============================================================================================"
cmds=""
for run in 1 2 3; do
    ckpts="./logs_${dataset}_latent-prediction_${run}/lightning_logs/checkpoints/ ./logs_${dataset}_encoder-decoder-64_${run}/lightning_logs/checkpoints/ ./logs_${dataset}_refine-64_${run}/lightning_logs/checkpoints/"
    cmds="${cmds}CUDA_VISIBLE_DEVICES=${gpu} python ../stability.py ../configs/${dataset}/latentpred/config${run}.yaml ${ckpts} ./logs_${dataset}_refine-64_${run}/prediction_long_term/hybrid_rollout_${step}/ 60; "
done
# The trailing `exec sh` keeps the screen window open after the last job.
screen -S "eval-${dataset}-long-term-stab" -dm bash -c "${cmds}exec sh"
================================================
FILE: scripts/refine_64_eval.sh
================================================
#!/bin/bash
# Evaluate ("eval-eval" mode of ../eval.py) the refine-64 models for
# configs/seeds 1-3 of <dataset>, sequentially, in one detached screen session
# on the given GPU. `exec sh` keeps the window open when the runs finish.
# Usage: refine_64_eval.sh <dataset> <gpu>
dataset=$1
gpu=$2
echo "===================================================================================="
echo "============== Evaluating refine-64 model on: $dataset (gpu id: $gpu) =============="
echo "===================================================================================="
job=""
for i in 1 2 3; do
    job+="CUDA_VISIBLE_DEVICES=${gpu} python ../eval.py ../configs/${dataset}/refine64/config${i}.yaml ./logs_${dataset}_refine-64_${i}/lightning_logs/checkpoints ./logs_${dataset}_encoder-decoder-64_${i}/lightning_logs/checkpoints eval-eval NA; "
done
screen -S eval-"$dataset"-refine-64 -dm bash -c "${job}exec sh"
================================================
FILE: scripts/refine_64_eval_gather.sh
================================================
#!/bin/bash
# Run ../eval.py in "eval-refine-train" (gathering) mode for the refine-64
# models of configs/seeds 1-3 of <dataset>, sequentially, in one detached
# screen session on the given GPU. `exec sh` keeps the window open afterwards.
# Usage: refine_64_eval_gather.sh <dataset> <gpu>
dataset=$1
gpu=$2
echo "================================================================================================"
echo "============== Evaluating (Gathering) refine-64 model on: $dataset (gpu id: $gpu) =============="
echo "================================================================================================"
job=""
for i in 1 2 3; do
    job+="CUDA_VISIBLE_DEVICES=${gpu} python ../eval.py ../configs/${dataset}/refine64/config${i}.yaml ./logs_${dataset}_refine-64_${i}/lightning_logs/checkpoints ./logs_${dataset}_encoder-decoder-64_${i}/lightning_logs/checkpoints eval-refine-train NA; "
done
screen -S eval-"$dataset"-refine-64 -dm bash -c "${job}exec sh"
================================================
FILE: scripts/refine_64_long_term_hybrid_rollout.sh
================================================
#!/bin/bash
# Produce hybrid-<step> long-term rollouts (../pred.py, 60 frames) with the
# refine-64 models of configs/seeds 1-3 of <dataset>, sequentially, in one
# detached screen session on the given GPU. `exec sh` keeps the window open.
# Usage: refine_64_long_term_hybrid_rollout.sh <dataset> <gpu> <step>
dataset=$1
gpu=$2
step=$3
echo "================================================================================================"
echo "======= Long-term hybrid-$step model rollout refine-64 model on: $dataset (gpu id: $gpu) ======="
echo "================================================================================================"
job=""
for i in 1 2 3; do
    job+="CUDA_VISIBLE_DEVICES=${gpu} python ../pred.py hybrid-${step} ../configs/${dataset}/refine64/config${i}.yaml ./logs_${dataset}_encoder-decoder-64_${i}/lightning_logs/checkpoints ./logs_${dataset}_refine-64_${i}/lightning_logs/checkpoints NA 60; "
done
screen -S eval-"$dataset"-longterm-hybridrollout -dm bash -c "${job}exec sh"
================================================
FILE: scripts/refine_64_long_term_model_rollout.sh
================================================
#!/bin/bash
# Produce pure model rollouts (../pred.py model-rollout, 60 frames) with the
# refine-64 models of configs/seeds 1-3 of <dataset>, sequentially, in one
# detached screen session on the given GPU. `exec sh` keeps the window open.
# Usage: refine_64_long_term_model_rollout.sh <dataset> <gpu>
dataset=$1
gpu=$2
echo "================================================================================================="
echo "============== Long-term model rollout refine-64 model on: $dataset (gpu id: $gpu) =============="
echo "================================================================================================="
job=""
for i in 1 2 3; do
    job+="CUDA_VISIBLE_DEVICES=${gpu} python ../pred.py model-rollout ../configs/${dataset}/refine64/config${i}.yaml ./logs_${dataset}_encoder-decoder-64_${i}/lightning_logs/checkpoints ./logs_${dataset}_refine-64_${i}/lightning_logs/checkpoints NA 60; "
done
screen -S eval-"$dataset"-longterm-modelrollout -dm bash -c "${job}exec sh"
================================================
FILE: scripts/refine_64_long_term_model_rollout_perturb_all.sh
================================================
#!/bin/bash
# Produce perturbed model rollouts (../pred.py, 60 frames) with the refine-64
# models of configs/seeds 1-3 of <dataset> for every perturbation kind
# (background_cover, background_replace, white_noise) and each of the suffixes
# -5, -7 and -8. All 27 runs execute sequentially inside one detached screen
# session on the given GPU; `exec sh` keeps the window open afterwards.
# Usage: refine_64_long_term_model_rollout_perturb_all.sh <dataset> <gpu>
dataset=$1
gpu=$2
echo "================================================================================================="
echo "============== Long-term model rollout refine-64 model on: $dataset (gpu id: $gpu) =============="
echo "================================================================================================="
job=""
for perturb in background_cover background_replace white_noise; do
    for level in 5 7 8; do
        for i in 1 2 3; do
            job+="CUDA_VISIBLE_DEVICES=${gpu} python ../pred.py model-rollout-perturb-${perturb}-${level} ../configs/${dataset}/refine64/config${i}.yaml ./logs_${dataset}_encoder-decoder-64_${i}/lightning_logs/checkpoints ./logs_${dataset}_refine-64_${i}/lightning_logs/checkpoints NA 60; "
        done
    done
done
screen -S eval-"$dataset"-longterm-modelrollout -dm bash -c "${job}exec sh"
================================================
FILE: scripts/refine_64_train.sh
================================================
#!/bin/bash
# Train the refine-64 models for configs 1-3 of <dataset> with ../main.py,
# sequentially, in one detached screen session on the given GPU.
# `exec sh` keeps the screen window open when training finishes.
# Usage: refine_64_train.sh <dataset> <gpu>
dataset=$1
gpu=$2
echo "==================================================================================="
echo "============== Training refine-64 model on: $dataset (gpu id: $gpu) ==============="
echo "==================================================================================="
job=""
for i in 1 2 3; do
    job+="CUDA_VISIBLE_DEVICES=${gpu} python ../main.py ../configs/${dataset}/refine64/config${i}.yaml; "
done
screen -S train-"$dataset"-refine-64 -dm bash -c "${job}exec sh"
================================================
FILE: stability.py
================================================
import os
import sys
import glob
import yaml
import json
import torch
import pprint
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from munch import munchify
from models_latentpred import VisLatentDynamicsModel
from dataset import NeuralPhysDataset
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
def mkdir(folder):
    """Make ``folder`` an empty directory.

    If the path already exists it is removed recursively with
    ``shutil.rmtree`` before being recreated, so previous contents are lost.
    """
    if not os.path.exists(folder):
        os.makedirs(folder)
        return
    shutil.rmtree(folder)
    os.makedirs(folder)
def load_config(filepath):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        filepath: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (the caller munchifies
        it and reads trainer parameters from it; ``None`` for an empty file).

    Raises:
        yaml.YAMLError: If the file is not valid YAML. The error is printed and
            then re-raised instead of being swallowed; swallowing it used to
            return ``None``, which the caller then tried to use as a config.
    """
    with open(filepath, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
def seed(cfg):
    """Seed PyTorch's random number generators from ``cfg.seed``.

    The CPU generator is always seeded. The CUDA generator (via
    ``torch.cuda.manual_seed``) is seeded only when ``cfg.if_cuda`` is truthy.
    """
    value = cfg.seed
    torch.manual_seed(value)
    if not cfg.if_cuda:
        return
    torch.cuda.manual_seed(value)
def get_data(filepath, size=(128, 128)):
    """Load one image frame as a float tensor with the channel axis first.

    Args:
        filepath: Path to the image file.
        size: ``(width, height)`` passed to ``Image.resize``; defaults to the
            128x128 resolution this script has always used.

    Returns:
        ``torch.float32`` tensor of shape (C, H, W) holding the resized pixel
        values divided by 255 (i.e. in [0, 1] for 8-bit images). The image
        must have a channel axis (e.g. RGB); a single-channel image yields a
        2-D array and the ``permute`` fails.
    """
    # The context manager closes the file handle immediately instead of
    # leaving it open until garbage collection; this function is called
    # repeatedly inside the evaluation loop.
    with Image.open(filepath) as img:
        resized = img.resize(size)
    data = torch.tensor(np.array(resized) / 255.0)
    return data.permute(2, 0, 1).float()
def _load_frame_pair(folder, first_idx, suf):
    """Load frames ``first_idx`` and ``first_idx + 1`` from ``folder`` as one CUDA tensor.

    Each frame comes from ``get_data`` as (C, H, W). The two frames are
    concatenated along dim 2 (width) and given a leading batch axis of size 1.
    """
    frames = [get_data(os.path.join(folder, f'{idx}.{suf}'))
              for idx in (first_idx, first_idx + 1)]
    return torch.cat(frames, 2).unsqueeze(0).cuda()


def main():
    """Measure latent-space prediction error along saved long-term rollouts.

    Positional command-line arguments:
        1. YAML config of the latent-prediction model.
        2. Checkpoint directory loaded with ``model.load_model``.
        3. High-dim checkpoint directory and
        4. refine checkpoint directory, both passed to
           ``model.load_high_dim_refine_model``.
        5. ``pred_save_path``: directory with one sub-folder of rollout frames
           per test video id; ``stability.npy`` is written here.
        6. ``pred_len``: input windows starting at a multiple of it are read
           from ground truth instead of from the rollout.

    For every test video and every even start index ``s``, the input frames
    ``s, s+1`` and the rollout frames ``s+2, s+3`` are encoded with
    ``model.data_to_state``. The input state is mapped through ``model.model``
    and compared to the target state with ``model.loss_func``. The errors are
    saved as ``np.array`` of one list per video.
    NOTE(review): this is rectangular only if every test video has the same
    number of frames.
    """
    if len(sys.argv) < 7:
        sys.exit('usage: python stability.py <config> <checkpoint_dir> <high_dim_checkpoint_dir> '
                 '<refine_checkpoint_dir> <pred_save_path> <pred_len>')
    config_filepath = str(sys.argv[1])
    checkpoint_filepath = str(sys.argv[2])
    high_dim_checkpoint_filepath = str(sys.argv[3])
    refine_checkpoint_filepath = str(sys.argv[4])
    pred_save_path = str(sys.argv[5])
    pred_len = int(sys.argv[6])

    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    seed(cfg)
    seed_everything(cfg.seed)

    log_dir = '_'.join([cfg.log_dir,
                        cfg.dataset,
                        cfg.model_name,
                        str(cfg.seed)])
    model = VisLatentDynamicsModel(lr=cfg.lr,
                                   seed=cfg.seed,
                                   if_cuda=cfg.if_cuda,
                                   if_test=True,
                                   gamma=cfg.gamma,
                                   log_dir=log_dir,
                                   train_batch=cfg.train_batch,
                                   val_batch=cfg.val_batch,
                                   test_batch=cfg.test_batch,
                                   num_workers=cfg.num_workers,
                                   model_name=cfg.model_name,
                                   data_filepath=cfg.data_filepath,
                                   dataset=cfg.dataset,
                                   lr_schedule=cfg.lr_schedule)
    model.load_model(checkpoint_filepath)
    model.load_high_dim_refine_model(high_dim_checkpoint_filepath, refine_checkpoint_filepath)
    model.extract_decoder_from_refine_model()
    model = model.to('cuda')
    model.eval()
    model.freeze()

    # Test video ids come from the split file written for this seed.
    data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset)
    split_filepath = os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json')
    with open(split_filepath, 'r') as file:
        test_vid_ids = json.load(file)['test']

    errors = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for p_vid_idx in tqdm(test_vid_ids):
            data_vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx))
            pred_vid_filepath = os.path.join(pred_save_path, str(p_vid_idx))
            frame_files = os.listdir(data_vid_filepath)
            total_num_frames = len(frame_files)
            suf = frame_files[0].split('.')[-1]
            error_p = []
            # Even start indices only; the last window's target ends at frame total-1.
            for start_frame_idx in range(0, total_num_frames - 3, 2):
                # Windows starting at a multiple of pred_len use ground-truth inputs.
                if start_frame_idx % pred_len == 0:
                    input_dir = data_vid_filepath
                else:
                    input_dir = pred_vid_filepath
                data = _load_frame_pair(input_dir, start_frame_idx, suf)
                target = _load_frame_pair(pred_vid_filepath, start_frame_idx + 2, suf)
                # Compute the error in latent space.
                data_state = model.data_to_state(data)
                target_state = model.data_to_state(target)
                output_state = model.model(data_state)
                error_p.append(model.loss_func(output_state, target_state).item())
            errors.append(error_p)
    np.save(os.path.join(pred_save_path, 'stability.npy'), np.array(errors))


if __name__ == '__main__':
    main()
================================================
FILE: utils/common.py
================================================
import os
import shutil
import json
import random
import numpy as np
from PIL import Image
import plotly.graph_objects as go
def mkdir(folder):
    """Recreate ``folder`` as an empty directory, discarding anything already there."""
    stale = os.path.exists(folder)
    if stale:
        # Recursive delete so the caller always starts from a clean folder.
        shutil.rmtree(folder)
    os.makedirs(folder)
def create_random_data_splits(seed, data_filepath, object_name, ratio=0.8, save_dir='./datainfo'):
    """Shuffle video ids into train/val/test splits and save them as JSON.

    The first ``ratio`` fraction of the shuffled ids becomes the training set.
    The remainder is split in half between validation and test.

    Args:
        seed: Seed for Python's global ``random`` module (reseeded here, so the
            split is reproducible for a given seed and directory size).
        data_filepath: Root data directory.
        object_name: Dataset sub-directory of ``data_filepath``. Its number of
            entries minus one is used as the number of videos.
        ratio: Fraction of videos assigned to training.
        save_dir: Directory under which
            ``<object_name>/data_split_dict_<seed>.json`` is written. It is
            created if missing. Defaults to the previously hard-coded
            ``./datainfo``.

    Returns:
        The split dict with keys ``'test'``, ``'val'`` and ``'train'``.
    """
    random.seed(seed)
    obj_filepath = os.path.join(data_filepath, object_name)
    # NOTE(review): one directory entry is excluded from the video count,
    # presumably a non-video file; confirm against the data layout.
    num_vids = len(os.listdir(obj_filepath)) - 1
    vid_id_lst = list(range(num_vids))
    random.shuffle(vid_id_lst)
    # Boundaries use exactly the original expressions so existing splits stay
    # bit-for-bit reproducible.
    train_end = int(num_vids * ratio)
    val_end = int(num_vids * (ratio + (1 - ratio) / 2))
    seq_dict = {
        'test': vid_id_lst[val_end:],
        'val': vid_id_lst[train_end:val_end],
        'train': vid_id_lst[:train_end],
    }
    out_dir = os.path.join(save_dir, object_name)
    # Previously the write failed if this folder did not already exist.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f'data_split_dict_{seed}.json'), 'w') as file:
        json.dump(seq_dict, file, indent=4)
    return seq_dict
def separate_eval_image(filepath, save_dir='.'):
    """Split a vertically stacked evaluation image into six 128x128 panels.

    Panel ``k`` (0-5) is cropped with box ``(2, 2 + 130*k, 130, 130 + 130*k)``,
    i.e. 128-pixel tiles separated by 2-pixel gaps. The panels are saved in
    order as input_0, input_1, truth_0, truth_1, output_0 and output_1 (PNG)
    in ``save_dir``, which defaults to the current directory as before.
    """
    names = ('input_0', 'input_1', 'truth_0', 'truth_1', 'output_0', 'output_1')
    # Context manager closes the source file once all crops are written.
    with Image.open(filepath) as img:
        for k, name in enumerate(names):
            top = 2 + 130 * k
            img.crop((2, top, 130, top + 128)).save(os.path.join(save_dir, f'{name}.png'))
def collect_long_term_pred_sample(dataset, seed, test_vid_id, hybrid_step, start=0, step=12, stop=60, suf='png',
                                  data_root='/data/kuang/visphysics_data', logs_root='/data/kuang/logs'):
    """Gather ground-truth and predicted frames of one test video for comparison.

    Recreates ``long_term_pred_sample_<dataset>_seed_<seed>_vid_<test_vid_id>``
    in the current directory (via ``mkdir``, which wipes it first) and writes:

    * ``input_0.png`` / ``input_1.png``: ground-truth frames ``start`` and ``start + 1``;
    * for every ``k`` in ``range(start + step, stop + 1, step)``: frame ``k - 1``
      of the ground truth as ``truth_<t>s.png`` and of each prediction scheme
      as ``<scheme>_<t>s.png``, with ``t = (k - start) / fps_value[dataset]``.

    Prediction schemes (all read from ``logs_root``):
        ``dim-8192``          encoder-decoder model rollout,
        ``dim-64``            encoder-decoder-64 model rollout,
        ``dim-<id>``          refine-64 model rollout,
        ``hybrid-64-<id>``    refine-64 hybrid rollout for ``hybrid_step``,
    where ``<id>`` is ``id_value[dataset]``.

    Args:
        data_root: Root of the raw frame folders ``<dataset>/<test_vid_id>``.
        logs_root: Root of the ``logs_*`` training directories.
        Both default to the previously hard-coded locations.
    """
    idv = id_value[dataset]
    fps = fps_value[dataset]
    save_path = f'long_term_pred_sample_{dataset}_seed_{seed}_vid_{test_vid_id}'
    mkdir(save_path)
    data_filepath = f'{data_root}/{dataset}/{test_vid_id}'
    rollout_subdir = f'prediction_long_term/model_rollout/{test_vid_id}'
    pred_filepaths = {
        'dim-8192': f'{logs_root}/logs_{dataset}_encoder-decoder_{seed}/{rollout_subdir}',
        'dim-64': f'{logs_root}/logs_{dataset}_encoder-decoder-64_{seed}/{rollout_subdir}',
        f'dim-{idv}': f'{logs_root}/logs_{dataset}_refine-64_{seed}/{rollout_subdir}',
        f'hybrid-64-{idv}': f'{logs_root}/logs_{dataset}_refine-64_{seed}/prediction_long_term/hybrid_rollout_{hybrid_step}/{test_vid_id}',
    }

    def _copy_frame(src_dir, frame_idx, dst_name):
        # Re-save one frame into save_path; the context manager closes the
        # source file right away instead of leaking one handle per copy.
        with Image.open(os.path.join(src_dir, str(frame_idx) + '.' + suf)) as img:
            img.save(os.path.join(save_path, dst_name))

    _copy_frame(data_filepath, start, 'input_0.png')
    _copy_frame(data_filepath, start + 1, 'input_1.png')
    for k in range(start + step, stop + 1, step):
        seconds = (k - start) / fps
        _copy_frame(data_filepath, k - 1, 'truth_%.1fs.png' % seconds)
        for scheme in ['dim-8192', 'dim-64', f'dim-{idv}', f'hybrid-64-{idv}']:
            _copy_frame(pred_filepaths[scheme], k - 1, scheme + '_%.1fs.png' % seconds)
"""
physical system parameters
"""
id_value = {'circular_motion':2, 'reaction_diffusion':2, 'single_pendulum':2,
'double_pendulum':4, 'swingstick_non_magnetic':4, 'elastic_pendulum':6,
'air_dancer':8, 'lava_lamp':8, 'fire':24}
fps_value = {'circular_motion':60, 'reaction_diffusion':5, 'single_pendulum':60,
'double_pendulum':60, 'swingstick_non_magnetic':60, 'elastic_pendulum':60,
'air_dancer':60, 'lava_lamp':2.5, 'fire':24}
"""
plot style
"""
cols = ['#df1e1e', '#f0975a', '#1b66cb', '#149650']
colorscale=[[0.0, "rgb(49,54,149)"],
[0.1111111111111111, "rgb(69,117,180)"],
[0.2222222222222222, "rgb(116,173,209)"],
[0.3333333333333333, "rgb(171,217,233)"],
[0.4444444444444444, "rgb(224,243,248)"],
[0.5555555555555556, "rgb(254,224,144)"],
[0.6666666666666666, "rgb(253,174,97)"],
[0.7777777777777778, "rgb(244,109,67)"],
[0.8888888888888888, "rgb(215,48,39)"],
[1.0, "rgb(165,0,38)"]]
def light_color(col, alpha=0.2):
    """Return a translucent ``'rgba(r,g,b,alpha)'`` version of a colour string.

    Args:
        col: ``'#rrggbb'`` hex string, or an ``'rgb(...)'`` / ``'rgba(...)'``
            functional string (any existing alpha component is replaced).
        alpha: Opacity written into the result; defaults to 0.2 as before.

    Returns:
        The ``'rgba(...)'`` string. Hex input keeps the original spacing
        (e.g. ``'rgba(223, 30, 30,0.2)'``).

    Raises:
        ValueError: If ``col`` is in neither supported format.
    """
    if col.startswith('rgb'):
        # Take the components from ``col`` itself; this branch used to slice
        # ``cols[n]`` with an undefined ``n`` and always raised NameError.
        inner = col[col.index('(') + 1:-1]
        rgb = ','.join(inner.split(',')[:3])
    elif col.startswith('#'):
        rgb = str(tuple(int(col.lstrip('#')[i:i + 2], 16) for i in (0, 2, 4)))[1:-1]
    else:
        raise ValueError(f'unsupported colour format: {col!r}')
    return 'rgba(' + rgb + ',' + str(alpha) + ')'
def update_figure(fig):
    """Apply the shared publication style to a plotly figure in place.

    Both axes: no grid, a 2 px black axis line, and 18 pt black sans-serif
    tick labels. Legend: normal trace order with 18 pt black sans-serif text.
    Base font: 24 pt black sans-serif. Paper and plot backgrounds are fully
    transparent.
    """
    axis_style = dict(showgrid=False,
                      tickfont=dict(family="sans serif", size=18, color="black"),
                      showline=True,
                      linewidth=2,
                      linecolor="black")
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    transparent = 'rgba(0,0,0,0)'
    fig.update_layout(
        legend=go.layout.Legend(traceorder="normal",
                                font=dict(family="sans serif", size=18, color="black")),
        font=dict(family="sans serif", size=24, color="black"),
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
    )
def test():
    """Render a sample figure with the shared palette and style, saved as ``test.png``.

    Draws four lines ``y = slope * x`` (slope 1-4, x = 0..29) coloured
    ``cols[0]`` to ``cols[3]``, applies ``update_figure`` and writes the image
    at 4x scale.
    """
    x = np.arange(30)
    fig = go.Figure()
    for slope in range(1, 5):
        trace_line = dict(width=4, color=cols[slope - 1])
        fig.add_trace(go.Scatter(x=x, y=x * slope, mode='lines', line=trace_line))
    update_figure(fig)
    fig.write_image('test.png', scale=4)


# Manual smoke test only: running this module directly writes test.png.
if __name__ == '__main__':
    test()
================================================
FILE: utils/dimension.py
================================================
import numpy as np
from scipy import stats
import os
import sys
import yaml
from munch import munchify
def load_config(filepath):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        filepath: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (callers munchify it
        and read ``log_dir``/``dataset``/``model_name``/``seed`` from it;
        ``None`` for an empty file).

    Raises:
        yaml.YAMLError: If the file is not valid YAML. The error is printed and
            then re-raised instead of being swallowed; swallowing it used to
            return ``None``, which callers then tried to use as a config.
    """
    with open(filepath, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
def calculate_intrinsic_dimension_statistics(dataset, model_type='model', if_image=False):
    """Print and return the mean/std of saved intrinsic-dimension estimates.

    For every config file in ``../configs/<dataset>/<model_type>``, the run's
    log directory ``<log_dir>_<dataset>_<model_name>_<seed>`` is rebuilt from
    the config, and ``variables/intrinsic_dimension[_image].npy`` is loaded
    from it. All loaded arrays are concatenated before the statistics are taken.

    Args:
        dataset: Dataset name (sub-directory of ``../configs``).
        model_type: Config group inside the dataset directory.
        if_image: Use the image-space estimates
            (``intrinsic_dimension_image.npy``) instead of the latent ones.

    Returns:
        ``(dim_mean, dim_std)``. ``dim_std`` is the population standard
        deviation (``np.std`` with its default ``ddof=0``).

    Raises:
        ValueError: If the config directory contains no files.
    """
    configs_dir = os.path.join('../configs', dataset, model_type)
    filename = 'intrinsic_dimension_image.npy' if if_image else 'intrinsic_dimension.npy'
    dims_all = []
    # Sorted so results do not depend on os.listdir's arbitrary order.
    for config_filepath in sorted(os.listdir(configs_dir)):
        cfg = munchify(load_config(filepath=os.path.join(configs_dir, config_filepath)))
        log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])
        dims_all.append(np.load(os.path.join(log_dir, 'variables', filename)))
    if not dims_all:
        raise ValueError(f'no config files found in {configs_dir}')
    dims_all = np.concatenate(dims_all)
    dim_mean = np.mean(dims_all)
    dim_std = np.std(dims_all)
    print('Mean (±std):', '%.2f (±%.2f)' % (dim_mean, dim_std))
    # NOTE(review): mean ± 1.96*std is the ~95% spread of individual estimates
    # (under normality), not a confidence interval for the mean itself.
    print('Confidence interval:', '(%.1f, %.1f)' % (dim_mean-1.96*dim_std, dim_mean+1.96*dim_std))
    return dim_mean, dim_std
def calculate_intrinsic_dimension_statistics_all_methods(dataset, model_type='model', if_image=False):
    """Print and return per-estimator mean/std of saved intrinsic-dimension results.

    For every config in ``../configs/<dataset>/<model_type>``, the run's
    ``variables/intrinsic_dimension[_image]_all_methods.npy`` is loaded. It
    holds a pickled mapping keyed by estimator name. Results for
    Levina_Bickel, MiND_ML and MiND_KL are concatenated across runs. Results
    for Hein and CD are collected as a plain list, one entry per run.

    Args:
        dataset: Dataset name (sub-directory of ``../configs``).
        model_type: Config group inside the dataset directory.
        if_image: Use the image-space results instead of the latent ones.

    Returns:
        Dict mapping each method name to ``(dim_mean, dim_std)`` (population std).

    Raises:
        ValueError: If the config directory contains no files.
    """
    configs_dir = os.path.join('../configs', dataset, model_type)
    if if_image:
        filename = 'intrinsic_dimension_image_all_methods.npy'
    else:
        filename = 'intrinsic_dimension_all_methods.npy'
    all_methods = ['Levina_Bickel', 'MiND_ML', 'MiND_KL', 'Hein', 'CD']
    # Estimators whose per-run results are not concatenated (presumably a
    # single global estimate per run; confirm with the producer of the file).
    unconcatenated = ('Hein', 'CD')
    dims_all = {method: [] for method in all_methods}
    config_files = sorted(os.listdir(configs_dir))  # deterministic order
    if not config_files:
        raise ValueError(f'no config files found in {configs_dir}')
    for config_filepath in config_files:
        cfg = munchify(load_config(filepath=os.path.join(configs_dir, config_filepath)))
        log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)])
        # allow_pickle is required because the file stores a Python object;
        # only load files produced by this project.
        dims = np.load(os.path.join(log_dir, 'variables', filename), allow_pickle=True).item()
        for method in all_methods:
            dims_all[method].append(dims[method])
    stats = {}
    for method in all_methods:
        values = dims_all[method]
        if method not in unconcatenated:
            values = np.concatenate(values)
        dim_mean = np.mean(values)
        dim_std = np.std(values)
        print(method + ':', '%.2f (±%.2f)' % (dim_mean, dim_std))
        stats[method] = (dim_mean, dim_std)
    return stats
if __name__ == '__main__':
    # Usage: python dimension.py <dataset>
    # Fail with a usage message instead of an IndexError when no dataset is given.
    if len(sys.argv) < 2:
        sys.exit('usage: python dimension.py <dataset>')
    dataset = str(sys.argv[1])
    calculate_intrinsic_dimension_statistics(dataset, 'model')
    # Alternative analyses; enable as needed:
    # calculate_intrinsic_dimension_statistics(dataset, 'model', if_image=True)
    # calculate_intrinsic_dimension_statistics_all_methods(dataset, 'model')