Repository: BoyuanChen/neural-state-variables Branch: main Commit: 10483d93ac8c Files: 227 Total size: 570.0 KB Directory structure: gitextract_ivg233ui/ ├── .gitignore ├── LICENSE ├── README.md ├── analysis/ │ ├── README.md │ ├── eval_intrinsic_dimension.py │ ├── eval_phys_data.py │ ├── eval_phys_double_pendulum/ │ │ ├── __init__.py │ │ ├── angle_estimator.py │ │ └── physics_estimator.py │ ├── eval_phys_elastic_pendulum/ │ │ ├── __init__.py │ │ ├── angle_estimator.py │ │ └── physics_estimator.py │ ├── eval_phys_long_term_pred.py │ ├── eval_phys_single_pendulum/ │ │ ├── __init__.py │ │ ├── angle_estimator.py │ │ └── physics_estimator.py │ ├── eval_regression.py │ ├── intrinsic_dimension_estimation/ │ │ ├── __init__.py │ │ ├── matlab_codes/ │ │ │ ├── DANCo.m │ │ │ ├── DANCoFit.m │ │ │ ├── DANCoTrain.m │ │ │ ├── DANCo_fits.mat │ │ │ ├── GetDim.mexw64 │ │ │ ├── KNN.m │ │ │ ├── MLE.m │ │ │ ├── MiND_KL.m │ │ │ ├── MiND_ML.m │ │ │ ├── demo_idEstimation.m │ │ │ ├── gauss.m │ │ │ ├── html/ │ │ │ │ └── demo_idEstimation.html │ │ │ ├── linSubspSpanOrthonormalize.m │ │ │ ├── parseParamsNamed.m │ │ │ ├── private/ │ │ │ │ ├── DANCo_estimateKL.m │ │ │ │ └── DANCo_statistics.m │ │ │ └── randsphere.m │ │ └── methods.py │ └── latent_regression/ │ ├── __init__.py │ └── regressors.py ├── configs/ │ ├── air_dancer/ │ │ ├── latentpred/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── circular_motion/ │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── double_pendulum/ │ │ ├── latentpred/ │ │ │ ├── config1.yaml │ 
│ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── elastic_pendulum/ │ │ ├── latentpred/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── fire/ │ │ ├── latentpred/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── lava_lamp/ │ │ ├── latentpred/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── reaction_diffusion/ │ │ ├── latentpred/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── single_pendulum/ │ │ ├── latentpred/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ 
└── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── swingstick_magnetic/ │ │ ├── model/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ ├── model64/ │ │ │ ├── config1.yaml │ │ │ ├── config2.yaml │ │ │ └── config3.yaml │ │ └── refine64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ └── swingstick_non_magnetic/ │ ├── latentpred/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── model/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ ├── model64/ │ │ ├── config1.yaml │ │ ├── config2.yaml │ │ └── config3.yaml │ └── refine64/ │ ├── config1.yaml │ ├── config2.yaml │ └── config3.yaml ├── datainfo/ │ ├── README.md │ ├── air_dancer/ │ │ ├── data_split_dict_1.json │ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ ├── circular_motion/ │ │ ├── data_split_dict_1.json │ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ ├── collect/ │ │ ├── circular_motion/ │ │ │ └── make_data.py │ │ ├── double_pendulum/ │ │ │ ├── README.md │ │ │ ├── convert_video.py │ │ │ ├── equalize_background.py │ │ │ └── split_data.py │ │ ├── elastic_pendulum/ │ │ │ └── make_data.py │ │ ├── fire/ │ │ │ └── split_data.py │ │ ├── lava_lamp/ │ │ │ └── split_data.py │ │ ├── reaction_diffusion/ │ │ │ ├── make_data.py │ │ │ ├── reaction_diffusion.m │ │ │ └── reaction_diffusion_rhs.m │ │ ├── single_pendulum/ │ │ │ └── make_data.py │ │ └── utils/ │ │ └── sort_vids.py │ ├── double_pendulum/ │ │ ├── data_split_dict_1.json │ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ ├── elastic_pendulum/ │ │ ├── data_split_dict_1.json │ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ ├── fire/ │ │ ├── data_split_dict_1.json │ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ ├── lava_lamp/ │ │ ├── data_split_dict_1.json │ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ ├── reaction_diffusion/ │ │ ├── data_split_dict_1.json 
│ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ ├── single_pendulum/ │ │ ├── data_split_dict_1.json │ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ ├── swingstick_magnetic/ │ │ ├── data_split_dict_1.json │ │ ├── data_split_dict_2.json │ │ └── data_split_dict_3.json │ └── swingstick_non_magnetic/ │ ├── data_split_dict_1.json │ ├── data_split_dict_2.json │ └── data_split_dict_3.json ├── dataset.py ├── eval.py ├── main.py ├── model_utils.py ├── models.py ├── models_latentpred.py ├── pred.py ├── requirements.txt ├── scripts/ │ ├── encoder_decoder_64_eval.sh │ ├── encoder_decoder_64_eval_gather.sh │ ├── encoder_decoder_64_long_term_model_rollout.sh │ ├── encoder_decoder_64_long_term_model_rollout_perturb_all.sh │ ├── encoder_decoder_64_train.sh │ ├── encoder_decoder_estimate_dimension.sh │ ├── encoder_decoder_eval.sh │ ├── encoder_decoder_eval_gather.sh │ ├── encoder_decoder_long_term_model_rollout.sh │ ├── encoder_decoder_long_term_model_rollout_perturb_all.sh │ ├── encoder_decoder_train.sh │ ├── latentpred_train.sh │ ├── long_term_eval_stability.sh │ ├── long_term_eval_stability_hybrid.sh │ ├── refine_64_eval.sh │ ├── refine_64_eval_gather.sh │ ├── refine_64_long_term_hybrid_rollout.sh │ ├── refine_64_long_term_model_rollout.sh │ ├── refine_64_long_term_model_rollout_perturb_all.sh │ └── refine_64_train.sh ├── stability.py └── utils/ ├── common.py └── dimension.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.*~ *~ __pycache__ scripts/logs_* ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021 Boyuan Chen Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the 
Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Discovering State Variables Hidden in Experimental Data [Boyuan Chen](http://boyuanchen.com/), [Kuang Huang](https://cm3.apam.columbia.edu/people/kuang-huang), [Sunand Raghupathi](https://www.linkedin.com/in/sunand-raghupathi?trk=public_profile_browsemap), [Ishaan Chandratreya](https://www.linkedin.com/in/ishaan-chandratreya-546aa6141), [Qiang Du](https://www.engineering.columbia.edu/faculty/qiang-du), [Hod Lipson](https://www.hodlipson.com/)
Columbia University
### [Project Website](https://www.cs.columbia.edu/~bchen/neural-state-variables) | [Video](https://youtu.be/KWHXchlJzSw) | [Paper](http://arxiv.org/abs/2112.10755) ## Overview This repo contains the PyTorch implementation for the paper "Discovering State Variables Hidden in Experimental Data". ![teaser](figures/teaser.gif) ## Citation If you find our paper or codebase helpful, please consider citing: ``` @article{chen2021discover, title={Discovering State Variables Hidden in Experimental Data}, author={Chen, Boyuan and Huang, Kuang and Raghupathi, Sunand and Chandratreya, Ishaan and Du, Qiang and Lipson, Hod}, journal={arXiv preprint arXiv:2112.10755}, year={2021} } ``` ## Content - [Installation](#installation) - [Logging](#logging) - [Data Preparation](#data-preparation) - [Training and Testing](#training-and-testing) - [Intrinsic Dimension Estimation](#intrinsic-dimension-estimation) - [Long-term Prediction and Stability Evaluation](#long-term-prediction-and-stability-evaluation) - [Evaluation and Analysis](#evaluation-and-analysis) - [License](#license) ## Installation The installation has been tested on Ubuntu 18.04 with CUDA 11.0. All the experiments are performed on one GeForce RTX 2080 Ti Nvidia GPU. Create a python virtual environment and install the dependencies. ``` virtualenv -p /usr/bin/python3.6 env3.6 source env3.6/bin/activate pip install -r requirements.txt ``` **Note**: You may need to install Matlab on your computer to use some of the data collectors and intrinsic dimension estimation algorithms. ## Logging We first introduce the naming convention of the saved files so that it is clear what will be saved and where they will be saved. 1. Log folder naming convention: ``` logs_{dataset}_{model_name}_{seed} ``` 2.
Inside the logs folder, the structure and contents are: ``` \logs_{dataset}_{model_name}_{seed} \lightning_logs \checkpoints [saved checkpoint] \version_0 [training stats] \version_1 [testing stats] \predictions [testing predicted images] \prediction_long_term [long term predicted images] \variables [file id and latent vectors on testing data] \variables_train [file id and latent vectors on training data] \variables_val [file id and latent vectors on validation data] ``` ## Data Preparation We provide nine datasets with their own download links below. - [circular_motion](https://drive.google.com/file/d/19DFYqh08B2L-YX_bTmpmxYu2jpARqoUA/view?usp=sharing) (circular motion system) - [reaction_diffusion](https://drive.google.com/file/d/1LOqpB0l86lELEegX5sV9NwHkGKeHoaTN/view?usp=sharing) (reaction diffusion system) - [single_pendulum](https://drive.google.com/file/d/1rw6vrVcV3KCIaVVwKMMhoZ9Jph5u-Zdf/view?usp=sharing) (single pendulum system) - [double_pendulum](https://drive.google.com/file/d/1QEtk4JjnRysIEjtkZIKBjACu_IiTAdX6/view?usp=sharing) (rigid double pendulum system) - [elastic_pendulum](https://drive.google.com/file/d/1Y8vzawQZhzPHp6cpLFuyM2LHuhgLjH9H/view?usp=sharing) (elastic double pendulum system) - [swingstick_non_magnetic](https://drive.google.com/file/d/1BfeGW4XTFyGdyBO0G_YnSnRyJGu2WRnc/view?usp=sharing) (swing stick system) - [air_dancer](https://drive.google.com/file/d/163KvwevY1fnDnI6WiWpc5Zaz-WWwPaAS/view?usp=sharing) (air dancer system) - [lava_lamp](https://drive.google.com/file/d/1R-I2CZaJLe2D4H818-n25l54R0pbKQfo/view?usp=sharing) (lava lamp system) - [fire](https://drive.google.com/file/d/1OIH6SalwPyD_lkXBLZq6bzmKfrp-xhrI/view?usp=sharing) (fire system) Save the downloaded dataset as ```data/{dataset_name}```, where ```data``` is your customized dataset folder. 
**Please make sure that ```data``` is an absolute path, and change the ```data_filepath``` item in the ```config.yaml``` files in ```configs``` to specify your customized dataset folder**. Please refer to the [datainfo](datainfo) folder for more details about the data structure and the dataset collection process. ## Training and Testing Our approach involves three models: - dynamics predictive model (encoder-decoder / encoder-decoder-64) - latent reconstruction model (refine-64) - neural latent dynamics model (latentpred) 1. Navigate to the scripts folder ``` cd scripts ``` 2. Train the dynamics predictive model (encoder-decoder and encoder-decoder-64) and then save the high-dimensional latent vectors from the testing data. ``` ./encoder_decoder_64_train.sh {dataset_name} {gpu no.} ./encoder_decoder_train.sh {dataset_name} {gpu no.} ./encoder_decoder_64_eval.sh {dataset_name} {gpu no.} ./encoder_decoder_eval.sh {dataset_name} {gpu no.} ``` 3. Run a forward pass on the trained encoder-decoder model and encoder-decoder-64 model to save the high-dimensional latent vectors from the training and validation data. The saved latent vectors will be used as the training and validation data for training and validating the latent reconstruction model. ``` ./encoder_decoder_eval_gather.sh {dataset_name} {gpu no.} ./encoder_decoder_64_eval_gather.sh {dataset_name} {gpu no.} ``` 4. Before you proceed with this step, please refer to the next section to obtain the system's intrinsic dimension and then come back to this step. Train the latent reconstruction model (refine-64) with the saved 64-dim latent vectors from previous steps. Then save the obtained Neural State Variables from both training and testing data. ``` ./refine_64_train.sh {dataset_name} {gpu no.} ./refine_64_eval.sh {dataset_name} {gpu no.} ./refine_64_eval_gather.sh {dataset_name} {gpu no.} ``` 5. Train the neural latent dynamics model (latentpred) with the trained models from previous steps.
``` ./latentpred_train.sh single_pendulum {dataset_name} {gpu no.} ``` ## Intrinsic Dimension Estimation With the trained dynamics predictive model, our approach provides subroutines to estimate the system's intrinsic dimension (ID) using manifold learning algorithms. The estimated intrinsic dimension will be used to decide the number of Neural State Variables and to design the latent reconstruction model. Only after this step you can proceed to train the latent reconstruction model to obtain Neural State Variables. 1. Navigate to the scripts folder ``` cd scripts ``` which is the default directory saving all models' log folders. 2. Estimate the intrinsic dimension from the saved latent vectors of the encoder-decoder models for all random seeds. ``` ./encoder_decoder_estimate_dimension.sh {dataset_name} ``` 3. Calculate the final intrinsic dimension estimated values (mean and standard deviation). ``` python ../utils/dimension.py {dataset_name} ``` ## Long-term Prediction and Stability Evaluation With all above trained models, our approach offers system long-term predictions through model rollouts as well as stability evaluation of the long-term predictions. 1. Navigate to the scripts folder ``` cd scripts ``` 2. Long-term prediction with single model rollouts. ``` ./encoder_decoder_long_term_model_rollout.sh {dataset_name} {gpu no.} ./encoder_decoder_64_long_term_model_rollout.sh {dataset_name} {gpu no.} ./refine_64_long_term_model_rollout.sh {dataset_name} {gpu no.} ``` The predictions will be saved in the ```prediction_long_term``` subfolder under the model's log folder. 3. Long-term prediction with hybrid model rollouts. ``` ./refine_64_long_term_hybrid_rollout.sh {dataset_name} {gpu no.} {step} ``` where ```step``` is the number of model rollouts via 64-dim latent vectors before a model rollout via Neural State Variables. The predictions will be saved in the ```prediction_long_term``` subfolder under the refine-64 model's log folder. 4. 
Long-term prediction with single model rollouts from perturbed initial frames. ``` ./encoder_decoder_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.} ./encoder_decoder_64_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.} ./refine_64_long_term_model_rollout_perturb_all.sh {dataset_name} {gpu no.} ``` The predictions will be saved in the ```prediction_long_term``` subfolder under the model's log folder. 5. Stability evaluation on long-term predictions with single model rollouts ``` ./long_term_eval_stability.sh {dataset_name} {gpu no.} ``` and on long-term predictions with hybrid model rollouts ``` ./long_term_eval_stability_hybrid.sh {dataset_name} {gpu no.} {step} ``` where ```step``` is as mentioned above for hybrid model rollouts. The evaluated latent space errors measuring the prediction stability will be saved in the ```stability.npy``` file under the respective ```prediction_long_term``` folders. **Note**: The default long-term prediction length is 60 (in frames). You will need to modify the scripts if you want to use a different prediction length. ## Evaluation and Analysis Please refer to the [analysis](analysis) folder for detailed instructions for physical evaluation and analysis. ## License This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details. ================================================ FILE: analysis/README.md ================================================ ## Physics Evaluation on Known Systems We provide physics evaluation algorithms extracting physical variables including positions, velocities, and energies from ground truth and predicted frames for the three known systems (single pendulum, rigid double pendulum, and elastic double pendulum). You can use these algorithms to evaluate the physical accuracy of predictions. 1. Evaluate physical variables from ground truth data.
``` python eval_phys_data.py dataset_name data_filepath ``` The ```dataset_name``` can be single_pendulum, double_pendulum, or elastic_pendulum. The ```data_filepath``` is the respective data filepath, e.g., ```data/single_pendulum``` where ```data``` is your customized data folder. The results will be saved in the ```phys_vars.npy``` file under ```data_filepath```. 2. Evaluate physical variables from long-term predictions via model rollouts and compute physical errors between ground truth data and the predictions. ``` python eval_phys_long_term_pred.py config_filepath pred_save_path ``` The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The ```pred_save_path``` is the path to the predicted images, e.g., ```../scripts/logs_single_pendulum_encoder-decoder_1/prediction_long_term/model_rollout/```. The results will be saved in the ```phys_vars.npy```, ```phys_error.npy```, and ```pixel_error.npy``` files under ```pred_save_path```. **Note**: ```eval_phys_single_pendulum```, ```eval_phys_double_pendulum```, and ```eval_phys_elastic_pendulum``` are helper packages for evaluating physical variables for the above command lines. ## Intrinsic Dimension Estimation 1. Navigate to the scripts folder ``` cd ../scripts ``` which is the default directory saving all models' log folders. 2. Estimate intrinsic dimension from model latent vectors using the Levina-Bickel method ``` python ../analysis/eval_intrinsic_dimension.py {config_filepath} model-latent NA ``` or using all methods (Levina-Bickel, MiND-ML, MiND-KL, Hein, CD) ``` python ../analysis/eval_intrinsic_dimension.py {config_filepath} model-latent all-methods ``` The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. 
The results will be saved in the ```intrinsic_dimension.npy``` file (using the Levina-Bickel method) or the ```intrinsic_dimension_all_methods.npy``` file (using all methods) in the ```variables``` subfolder under the model's log folder. 3. Estimate intrinsic dimension from raw data (image pairs) using the Levina-Bickel method ``` python ../analysis/eval_intrinsic_dimension.py {config_filepath} data-image NA ``` or using all methods (Levina-Bickel, MiND-ML, MiND-KL, Hein, CD) ``` python ../analysis/eval_intrinsic_dimension.py {config_filepath} data-image all-methods ``` The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The results will be saved in the ```intrinsic_dimension_image.npy``` file (using the Levina-Bickel method) or the ```intrinsic_dimension_image_all_methods.npy``` file (using all methods) in the ```variables``` subfolder under the model's log folder. Here the model configuration only provides data filepath and test video ids. **Note**: ```intrinsic_dimension_estimation``` is a helper package providing various intrinsic dimension estimation methods. When choosing the ```all-methods``` option, MATLAB and the MATLAB Engine API for Python must be installed on the machine; see instructions [here](https://www.mathworks.com/help/matlab/matlab_external/install-the-matlab-engine-for-python.html). References of MATLAB codes: ``` [1] Gabriele Lombardi (2021). Intrinsic dimensionality estimation techniques (https://www.mathworks.com/matlabcentral/fileexchange/40112-intrinsic-dimensionality-estimation-techniques), MATLAB Central File Exchange. Retrieved October 21, 2021. [2] M. Hein and J.-Y. Audibert, Intrinsic dimensionality estimation of submanifolds in Euclidean space, Proceedings of the 22nd International Conference on Machine Learning (https://www.ml.uni-saarland.de/code/IntDim/IntDim.htm). ``` ## Latent Space Regression on Known Systems 1.
Navigate to the scripts folder ``` cd ../scripts ``` which is the default directory saving all models' log folders. 2. Regress physical variables from the model latent vectors ``` python ../analysis/eval_regression.py config_filepath NA ``` or from the first ```num_components``` principal components of the model latent vectors ``` python ../analysis/eval_regression.py config_filepath num_components ``` The ```config_filepath``` is the model configuration filepath, e.g., ```../configs/single_pendulum/model/config1.yaml```. The results will be saved in the ```regression_results.npy``` file (with model latent vectors) or the ```regression_results_pca_{num_components}.npy``` file (with first few principal components of model latent vectors) in the ```variables``` subfolder under the model's log folder. ================================================ FILE: analysis/eval_intrinsic_dimension.py ================================================ import numpy as np import os import sys from tqdm import tqdm from PIL import Image import json import yaml import pprint from munch import munchify from intrinsic_dimension_estimation import ID_Estimator def load_config(filepath): with open(filepath, 'r') as stream: try: trainer_params = yaml.safe_load(stream) return trainer_params except yaml.YAMLError as exc: print(exc) def remove_duplicates(X): return np.unique(X, axis=0) def eval_id_latent(vars_filepath, if_refine, if_all_methods): if if_refine: latent = np.load(os.path.join(vars_filepath, 'refine_latent.npy')) else: latent = np.load(os.path.join(vars_filepath, 'latent.npy')) print(f'Number of samples: {latent.shape[0]}; Latent dimension: {latent.shape[1]}') latent = remove_duplicates(latent) print(f'Number of samples (duplicates removed): {latent.shape[0]}') estimator = ID_Estimator() k_list = (latent.shape[0] * np.linspace(0.008, 0.016, 5)).astype('int') print(f'List of numbers of nearest neighbors: {k_list}') if if_all_methods: dims = estimator.fit_all_methods(latent, 
k_list) np.save(os.path.join(vars_filepath, 'intrinsic_dimension_all_methods.npy'), dims) else: dims = estimator.fit(latent, k_list) np.save(os.path.join(vars_filepath, 'intrinsic_dimension.npy'), dims) def eval_id_image(data_filepath, test_vid_ids, vars_filepath, if_all_methods): print('Reading image data...') data = [] for p_vid_idx in tqdm(test_vid_ids): vid_filepath = os.path.join(data_filepath, str(p_vid_idx)) num_frames = len(os.listdir(vid_filepath)) suf = os.listdir(vid_filepath)[0].split('.')[-1] for p_frame in range(num_frames - 3): img_list = [] for p in range(2): img = Image.open(os.path.join(vid_filepath, str(p_frame + p) + '.' + suf)) img = img.resize((128, 128)) img = np.array(img) / 255.0 img_list.append(img) data_p = np.concatenate(img_list, 1).reshape([1, -1]) data.append(data_p) data = np.concatenate(data, 0) print(f'Number of samples: {data.shape[0]}; Image dimension (flattened): {data.shape[1]}') data = remove_duplicates(data) print(f'Number of samples (duplicates removed): {data.shape[0]}') estimator = ID_Estimator() k_list = (data.shape[0] * np.linspace(0.008, 0.016, 5)).astype('int') print(f'List of numbers of nearest neighbors: {k_list}') if if_all_methods: dims = estimator.fit_all_methods(data, k_list) np.save(os.path.join(vars_filepath, 'intrinsic_dimension_image_all_methods.npy'), dims) else: dims = estimator.fit(data, k_list) np.save(os.path.join(vars_filepath, 'intrinsic_dimension_image.npy'), dims) if __name__ == '__main__': config_filepath = str(sys.argv[1]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) if_all_methods = (str(sys.argv[3]) == 'all-methods') if str(sys.argv[2]) == 'model-latent': log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) vars_filepath = os.path.join(log_dir, 'variables') if_refine = ('refine' in cfg.model_name) dims = eval_id_latent(vars_filepath, if_refine, if_all_methods) elif str(sys.argv[2]) == 'data-image': data_filepath = 
os.path.join(cfg.data_filepath, cfg.dataset) with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file: seq_dict = json.load(file) test_vid_ids = seq_dict['test'] log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) vars_filepath = os.path.join(log_dir, 'variables') dims = eval_id_image(data_filepath, test_vid_ids, vars_filepath, if_all_methods) else: assert False, 'Invalid arguments...' ================================================ FILE: analysis/eval_phys_data.py ================================================ import os import sys import cv2 import numpy as np from tqdm import tqdm def eval_phys_data_single_pendulum(data_filepath, num_vids, num_frms, save_path): from eval_phys_single_pendulum import eval_physics, phys_vars_list phys = {p_var:[] for p_var in phys_vars_list} for n in tqdm(range(num_vids)): seq_filepath = os.path.join(data_filepath, str(n)) frames = [] for p in range(num_frms): frame_p = cv2.imread(os.path.join(seq_filepath, str(p)+'.png')) frames.append(frame_p) phys_tmp = eval_physics(frames) for p_var in phys_vars_list: phys[p_var].append(phys_tmp[p_var]) for p_var in phys_vars_list: phys[p_var] = np.array(phys[p_var]) np.save(save_path, phys) def eval_phys_data_double_pendulum(data_filepath, num_vids, num_frms, save_path): from eval_phys_double_pendulum import eval_physics, phys_vars_list phys = {p_var:[] for p_var in phys_vars_list} for n in tqdm(range(num_vids)): seq_filepath = os.path.join(data_filepath, str(n)) frames = [] for p in range(num_frms): frame_p = cv2.imread(os.path.join(seq_filepath, str(p)+'.png')) frames.append(frame_p) phys_tmp = eval_physics(frames) for p_var in phys_vars_list: phys[p_var].append(phys_tmp[p_var]) for p_var in phys_vars_list: phys[p_var] = np.array(phys[p_var]) # remove outliers thresh_1 = np.nanpercentile(np.abs(phys['vel_theta_1']), 98) thresh_2 = np.nanpercentile(np.abs(phys['vel_theta_2']), 98) for n in range(num_vids): for p in 
range(num_frms): if (not np.isnan(phys['vel_theta_1'][n, p]) and np.abs(phys['vel_theta_1'][n, p]) >= thresh_1) \ or (not np.isnan(phys['vel_theta_2'][n, p]) and np.abs(phys['vel_theta_2'][n, p]) >= thresh_2): phys['vel_theta_1'][n, p] = np.nan phys['vel_theta_2'][n, p] = np.nan phys['kinetic energy'][n, p] = np.nan phys['total energy'][n, p] = np.nan np.save(save_path, phys) def eval_phys_data_elastic_pendulum(data_filepath, num_vids, num_frms, save_path): from eval_phys_elastic_pendulum import eval_physics, phys_vars_list phys = {p_var:[] for p_var in phys_vars_list} for n in tqdm(range(num_vids)): seq_filepath = os.path.join(data_filepath, str(n)) frames = [] for p in range(num_frms): frame_p = cv2.imread(os.path.join(seq_filepath, str(p)+'.png')) frames.append(frame_p) phys_tmp = eval_physics(frames) for p_var in phys_vars_list: phys[p_var].append(phys_tmp[p_var]) for p_var in phys_vars_list: phys[p_var] = np.array(phys[p_var]) # remove outliers thresh_1 = np.nanpercentile(np.abs(phys['vel_theta_1']), 98) thresh_2 = np.nanpercentile(np.abs(phys['vel_theta_2']), 98) thresh_z = np.nanpercentile(np.abs(phys['vel_z']), 98) for n in range(num_vids): for p in range(num_frms): if (not np.isnan(phys['vel_theta_1'][n, p]) and np.abs(phys['vel_theta_1'][n, p]) >= thresh_1) \ or (not np.isnan(phys['vel_theta_2'][n, p]) and np.abs(phys['vel_theta_2'][n, p]) >= thresh_2) \ or (not np.isnan(phys['vel_z'][n, p]) and np.abs(phys['vel_z'][n, p]) >= thresh_z): phys['vel_theta_1'][n, p] = np.nan phys['vel_theta_2'][n, p] = np.nan phys['vel_z'][n, p] = np.nan phys['kinetic energy'][n, p] = np.nan phys['total energy'][n, p] = np.nan np.save(save_path, phys) if __name__ == '__main__': dataset = str(sys.argv[1]) data_filepath = str(sys.argv[2]) save_path = os.path.join(data_filepath, 'phys_vars.npy') if dataset == 'single_pendulum': eval_phys_data_single_pendulum(data_filepath, 1200, 60, save_path) elif dataset == 'double_pendulum': eval_phys_data_double_pendulum(data_filepath, 1100, 
60, save_path) elif dataset == 'elastic_pendulum': eval_phys_data_elastic_pendulum(data_filepath, 1200, 60, save_path) else: assert False, 'Unknown system...' ================================================ FILE: analysis/eval_phys_double_pendulum/__init__.py ================================================ from .angle_estimator import obtain_angle from .physics_estimator import * import numpy as np phys_vars_list = {'reject', 'theta_1', 'vel_theta_1', 'theta_2', 'vel_theta_2', 'kinetic energy', 'potential energy', 'total energy'} def eval_physics(frames): num_frames = len(frames) reject = np.zeros(num_frames, dtype=bool) theta_1 = np.zeros(num_frames) theta_2 = np.zeros(num_frames) # estimate angles for p in range(num_frames): reject_p, angles_p, _ = obtain_angle(frames[p]) if reject_p: reject[p] = True theta_1[p] = np.nan theta_2[p] = np.nan else: reject[p] = False theta_1[p] = angles_p[0] theta_2[p] = angles_p[1] # calculate velocities vel_theta_1 = np.zeros(num_frames) vel_theta_2 = np.zeros(num_frames) sub_ids = np.ma.clump_unmasked(np.ma.masked_array(theta_1, reject)) for ids in sub_ids: vel_theta_1[ids] = calc_velocity(theta_1[ids].copy()) vel_theta_2[ids] = calc_velocity(theta_2[ids].copy()) vel_theta_1[reject] = np.nan vel_theta_2[reject] = np.nan # calculate energies kinetic_energy, potential_energy, total_energy = calc_energy(theta_1, theta_2, vel_theta_1, vel_theta_2) # save results phys = dict.fromkeys(phys_vars_list) phys['reject'] = reject phys['theta_1'] = theta_1 phys['vel_theta_1'] = vel_theta_1 phys['theta_2'] = theta_2 phys['vel_theta_2'] = vel_theta_2 phys['kinetic energy'] = kinetic_energy phys['potential energy'] = potential_energy phys['total energy'] = total_energy return phys ================================================ FILE: analysis/eval_phys_double_pendulum/angle_estimator.py ================================================ ''' This script provides utility functions estimating angles of a double pendulum from image. Step 1. 
Extract each of the two pendulums from the image. Step 2. Do a rectangle fitting of each pendulum. Step 3. Estimate the angles of the pendulums. Certain images will be rejected if any pendulum does not exist, is hidden, or has a wrong shape. ''' import cv2 import os import numpy as np ''' Extract the two pendulums from the image. Args: img: double pendulum image in BGR format Returns: seg_1: segmentation of the first pendulum seg_2: segmentation of the second pendulum ''' def seg_from_img(img): # pixel thresholds (in HSV) v_min_1 = (0, 0, 0) v_max_1 = (255, 255, 140) v_min_2 = (60, 0, 0) v_max_2 = (255, 255, 255) # background color in HSV (RGB: [215, 205, 192]) bg_color = [17, 27, 215] img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) # extract the first pendulum seg_1 = cv2.inRange(img_hsv, v_min_1, v_max_1) # remove the first pendulum by replacing it with background pixels img_hsv[seg_1==255] = bg_color # extract the second pendulum seg_2 = cv2.inRange(img_hsv, v_min_2, v_max_2) return seg_1, seg_2 ''' Fit pendulum in a rectangle. Args: seg: segmentation of the pendulum Returns: rej: (True/False) if the image is rejected rect: (Box2D structure) the fitted rectangle ''' def fit_pendulum(seg): # find all contours contours, _ = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) # reject if no contours found if len(contours) == 0: return True, None # find the contour with the maximum area cnt = max(contours, key=cv2.contourArea) area = cv2.contourArea(cnt) # reject if the contour is too small if area < 125: return True, None # rectangle fitting rect = cv2.minAreaRect(cnt) _, (width, height), _ = rect # reject if the rectangle is too close to a square if abs(height-width) < 8: return True, None # reject if the rectangle does not properly fit the contour if height*width > 2 * area: return True, None return False, rect ''' Estimate the angle of pendulum from its fitted rectangle. 
Args: rect: (Box2D structure) the fitted rectangle ref_pt: reference point to determine the arrow direction Returns: angle: the estimated angle in radians in range (0, 2*pi) box: box points of the rectangle arrow: the arrow pointing along the angle ''' def estimate_angle(rect, ref_pt): # box points box = cv2.boxPoints(rect) # center, width and height (cx, cy), (width, height), _ = rect # specify the direction vector of the pendulum if width < height: v_d = box[0] - box[1] else: v_d = box[0] - box[3] # reference vector v_ref = np.array([cx, cy]) - ref_pt # choose the one from v_d and -v_d whose angle with v_ref # is less than pi if np.dot(v_d, v_ref) < 0: v_d = -v_d # counterclockwise angle from (0,-1) to v_d angle = np.arctan2(-v_d[1], v_d[0]) + np.pi/2 if angle < 0: angle += 2*np.pi arrow = ((int(cx), int(cy)), (int(cx+v_d[0]), int(cy+v_d[1]))) return angle, box, arrow ''' Obtain the angles of double pendulum from image Args: img: double pendulum image in BGR format Returns: rej: (True/False) if the image is rejected angles: the estimated angles in radians in range (0, 2*pi) img_marked: image marked with the fitted rectangles, the direction vectors and the estimated angles (BGR format) ''' def obtain_angle(img): img_marked = img.copy() seg_1, seg_2 = seg_from_img(img) rej_1, rect_1 = fit_pendulum(seg_1) rej_2, rect_2 = fit_pendulum(seg_2) if rej_1: rej = True angles = [] cv2.putText(img_marked, 'Reject arm 1', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) elif rej_2: rej = True angles = [] cv2.putText(img_marked, 'Reject arm 2', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) else: rej = False img_center = np.array([64, 64]) angle_1, box_1, arrow_1 = estimate_angle(rect_1, img_center) (cx, cy), _, _ = rect_1 pd_1_end = 2 * np.array([cx, cy]) - img_center angle_2, box_2, arrow_2 = estimate_angle(rect_2, pd_1_end) angles = [angle_1, angle_2] # mark the fitted rectangles cv2.drawContours(img_marked, [np.int0(box_1)], 0, (0,0,255), 2) 
cv2.drawContours(img_marked, [np.int0(box_2)], 0, (0,0,255), 2) # mark the direction vectors cv2.arrowedLine(img_marked, arrow_1[0], arrow_1[1], (0, 0, 255), 1, tipLength=0.25) cv2.arrowedLine(img_marked, arrow_2[0], arrow_2[1], (0, 0, 255), 1, tipLength=0.25) # mark the estimated angles in degrees text = str(round(angle_1*180/np.pi)) + ' ' + str(round(angle_2*180/np.pi)) cv2.putText(img_marked, text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) return rej, angles, img_marked ================================================ FILE: analysis/eval_phys_double_pendulum/physics_estimator.py ================================================ ''' This script provides utility functions estimating angular velocities and other physical quantities from angles of double pendulum. ''' import numpy as np from scipy.interpolate import CubicSpline # physical parameters fps = 60 # frames per second L1, L2 = 0.205, 0.179 # pendulum rod lengths (m) w1, w2 = 0.038, 0.038 # pendulum rod widths (m) m1, m2 = 0.262, 0.110 # bob masses (kg) g = 9.81 # gravitational acceleration (m/s^2) ''' Calculate the absolute difference between two angles th1 and th2 on a circle. Assumed that the absolute difference between the two angles is within range (0,pi). ''' def calc_diff(th1, th2): diff = np.abs(th2 - th1) diff = np.minimum(diff, 2*np.pi-diff) return diff ''' Calculate the average of two angles th1 and th2 on a circle. Assumed that the absolute difference between the two angles is within range (0,pi). ''' def calc_avrg(th1, th2): avrg = (th1 + th2) / 2 diff = np.abs(th2 - th1) if diff > np.pi: avrg -= np.pi if avrg < 0: avrg += 2*np.pi return avrg ''' Calculate angular velocities from a sequence of angles using numerical differentiation. method='fd': finite difference; method='spline': cubic spline fitting. 
''' def calc_velocity(th, method='spline'): len_seq = th.shape[0] # isolated data if len_seq == 1: return np.nan # preprocessing: periodic extension of angles for i in range(1, len_seq): if th[i] - th[i-1] > np.pi: th[i:] -= 2*np.pi elif th[i] - th[i-1] < -np.pi: th[i:] += 2*np.pi vel_th = np.zeros(len_seq) # finite difference if method == 'fd': for i in range(1, len_seq): vel_th[i] = (th[i] - th[i-1]) * fps vel_th[0] = (th[1] - th[0]) * fps # cubic spline fitting elif method == 'spline': t = np.arange(len_seq) / fps cs = CubicSpline(t, th) vel_th = cs(t, 1) # use finite difference at boundary points to improve accuracy vel_th[0] = (th[1] - th[0]) * fps vel_th[-1] = (th[-1] - th[-2]) * fps else: assert False, 'Unrecognizable differentiation method!' return vel_th ''' Calculate energies from angles and angular velocities ''' def calc_energy(th1, th2, vel_th1, vel_th2): # centers of masses in x-y coordinates x1 = (L1 / 2) * np.sin(th1) y1 = (-L1 / 2) * np.cos(th1) x2 = (L1 * np.sin(th1)) + (L2 / 2) * np.sin(th2) y2 = (-L1 * np.cos(th1)) - (L2 / 2) * np.cos(th2) # velocities in x-y coordinates vel_x1 = vel_th1 * (L1 / 2) * np.cos(th1) vel_y1 = vel_th1 * (L1 / 2) * np.sin(th1) vel_x2 = vel_th1 * L1 * np.cos(th1) + (vel_th2 / 2) * L2 * np.cos(th2) vel_y2 = vel_th1 * L1 * np.sin(th1) + (vel_th2 / 2) * L2 * np.sin(th2) # moments of inertia I1 = (1. / 12.) * m1 * (w1 ** 2 + L1 ** 2) I2 = (1. / 12.) 
* m2 * (w2 ** 2 + L2 ** 2) # potential energy V = g * (m1 * y1 + m2 * y2) # kinetic energy (translation + rotation) T = 0.5 * m1 * (vel_x1 ** 2 + vel_y1 ** 2) + 0.5 * m2 * (vel_x2 ** 2 + vel_y2 ** 2) + 0.5 * I1 * ( vel_th1 ** 2) + 0.5 * I2 * (vel_th2 ** 2) # total energy E = T + V return T, V, E ================================================ FILE: analysis/eval_phys_elastic_pendulum/__init__.py ================================================ from .angle_estimator import obtain_angle_stretch from .physics_estimator import * import numpy as np phys_vars_list = {'reject', 'theta_1', 'vel_theta_1', 'theta_2', 'vel_theta_2', 'z', 'vel_z', 'kinetic energy', 'potential energy', 'total energy'} def eval_physics(frames): num_frames = len(frames) reject = np.zeros(num_frames, dtype=bool) theta_1 = np.zeros(num_frames) theta_2 = np.zeros(num_frames) z = np.zeros(num_frames) # estimate angles and stretch for p in range(num_frames): reject_p, angles_p, stretch_p, _ = obtain_angle_stretch(frames[p]) if reject_p: reject[p] = True theta_1[p] = np.nan theta_2[p] = np.nan z[p] = np.nan else: reject[p] = False theta_1[p] = angles_p[0] theta_2[p] = angles_p[1] z[p] = stretch_p # calculate velocities vel_theta_1 = np.zeros(num_frames) vel_theta_2 = np.zeros(num_frames) vel_z = np.zeros(num_frames) sub_ids = np.ma.clump_unmasked(np.ma.masked_array(theta_1, reject)) for ids in sub_ids: vel_theta_1[ids] = calc_velocity(theta_1[ids].copy()) vel_theta_2[ids] = calc_velocity(theta_2[ids].copy()) vel_z[ids] = calc_velocity(z[ids].copy(), periodic=False) vel_theta_1[reject] = np.nan vel_theta_2[reject] = np.nan vel_z[reject] = np.nan # calculate energies kinetic_energy, potential_energy, total_energy = calc_energy(theta_1, theta_2, z, vel_theta_1, vel_theta_2, vel_z) # save results phys = dict.fromkeys(phys_vars_list) phys['reject'] = reject phys['theta_1'] = theta_1 phys['vel_theta_1'] = vel_theta_1 phys['theta_2'] = theta_2 phys['vel_theta_2'] = vel_theta_2 phys['z'] = z phys['vel_z'] = 
vel_z phys['kinetic energy'] = kinetic_energy phys['potential energy'] = potential_energy phys['total energy'] = total_energy return phys ================================================ FILE: analysis/eval_phys_elastic_pendulum/angle_estimator.py ================================================ ''' This script provides utility functions estimating angles and stretch of a elastic double pendulum from image. Step 1. Extract each of the two pendulums from the image. Step 2. Do a rectangle fitting of each pendulum. Step 3. Estimate the angles of the pendulums and the stretch of the elastic one. Certain images will be rejected if any pendulum does not exist, is hidden, or has a wrong shape. ''' import cv2 import os import numpy as np ''' Extract the two pendulums from the image. Args: img: elastic double pendulum image in BGR format Returns: seg_1: segmentation of the first pendulum seg_2: segmentation of the second pendulum ''' def seg_from_img(img): # pixel thresholds (in HSV) v_min_1 = (0, 0, 0) v_max_1 = (255, 255, 140) v_min_2 = (60, 0, 0) v_max_2 = (255, 255, 255) # background color in HSV (RGB: [215, 205, 192]) bg_color = [17, 27, 215] img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) # extract the first pendulum seg_1 = cv2.inRange(img_hsv, v_min_1, v_max_1) # remove the first pendulum by replacing it with background pixels img_hsv[seg_1==255] = bg_color # extract the second pendulum seg_2 = cv2.inRange(img_hsv, v_min_2, v_max_2) return seg_1, seg_2 ''' Fit pendulum in a rectangle. 
Args: seg: segmentation of the pendulum Returns: rej: (True/False) if the image is rejected rect: (Box2D structure) the fitted rectangle ''' def fit_pendulum(seg): # find all contours contours, _ = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) # reject if no contours found if len(contours) == 0: return True, None # find the contour with the maximum area cnt = max(contours, key=cv2.contourArea) area = cv2.contourArea(cnt) # reject if the contour is too small if area < 60: return True, None # rectangle fitting rect = cv2.minAreaRect(cnt) _, (width, height), _ = rect # reject if the rectangle is too close to a square if abs(height-width) < 4: return True, None # reject if the rectangle does not properly fit the contour if height*width > 1.8 * area: return True, None return False, rect ''' Estimate the angle of pendulum from its fitted rectangle. Args: rect: (Box2D structure) the fitted rectangle ref_pt: reference point to determine the arrow direction Returns: angle: the estimated angle in radians in range (0, 2*pi) box: box points of the rectangle arrow: the arrow pointing along the angle ''' def estimate_angle(rect, ref_pt): # box points box = cv2.boxPoints(rect) # center, width and height (cx, cy), (width, height), _ = rect # specify the direction vector of the pendulum if width < height: v_d = box[0] - box[1] else: v_d = box[0] - box[3] # reference vector v_ref = np.array([cx, cy]) - ref_pt # choose the one from v_d and -v_d whose angle with v_ref # is less than pi if np.dot(v_d, v_ref) < 0: v_d = -v_d # counterclockwise angle from (0,-1) to v_d angle = np.arctan2(-v_d[1], v_d[0]) + np.pi/2 if angle < 0: angle += 2*np.pi arrow = ((int(cx), int(cy)), (int(cx+v_d[0]), int(cy+v_d[1]))) return angle, box, arrow ''' Obtain the angles and stretch of elastic double pendulum from image Args: img: elastic double pendulum image in BGR format Returns: rej: (True/False) if the image is rejected angles: the estimated angles in radians in range (0, 2*pi) 
stretch: the estimated stretch of the elastic pendulum (unit: m) img_marked: image marked with the fitted rectangles, the direction vectors and the estimated angles and stretch (BGR format) ''' def obtain_angle_stretch(img): img_marked = img.copy() seg_1, seg_2 = seg_from_img(img) rej_1, rect_1 = fit_pendulum(seg_1) rej_2, rect_2 = fit_pendulum(seg_2) if rej_1: rej = True angles = [] stretch = None cv2.putText(img_marked, 'Reject arm 1', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) elif rej_2: rej = True angles = [] stretch = None cv2.putText(img_marked, 'Reject arm 2', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) else: rej = False img_center = np.array([64, 64]) angle_1, box_1, arrow_1 = estimate_angle(rect_1, img_center) (cx, cy), (width, height), _ = rect_1 pd_1_end = 2 * np.array([cx, cy]) - img_center angle_2, box_2, arrow_2 = estimate_angle(rect_2, pd_1_end) angles = [angle_1, angle_2] # compute the stretch of the elastic pendulum (unit: m) # max(width, height): stretched length of the elastic pendulum in pixels L_0_img = 24 # unstretched length of the elastic pendulum in pixels L_0 = 0.205 # unstretched length of the elastic pendulum (m) stretch = (max(width, height) - L_0_img) / L_0_img * L_0 # mark the fitted rectangles cv2.drawContours(img_marked, [np.int0(box_1)], 0, (0,0,255), 2) cv2.drawContours(img_marked, [np.int0(box_2)], 0, (0,0,255), 2) # mark the direction vectors cv2.arrowedLine(img_marked, arrow_1[0], arrow_1[1], (0, 0, 255), 1, tipLength=0.25) cv2.arrowedLine(img_marked, arrow_2[0], arrow_2[1], (0, 0, 255), 1, tipLength=0.25) # mark the estimated angles (degrees) and stretch (m) text = str(round(angle_1*180/np.pi)) + ' ' + str(round(angle_2*180/np.pi)) + ' ' + ('%.2f' % stretch) cv2.putText(img_marked, text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 0, 0), 1) return rej, angles, stretch, img_marked ================================================ FILE: analysis/eval_phys_elastic_pendulum/physics_estimator.py 
================================================ ''' This script provides utility functions estimating velocities and other physical quantities of elastic double pendulum. ''' import numpy as np from scipy.interpolate import CubicSpline # physical parameters fps = 60 # frames per second L_0 = 0.205 # elastic pendulum unstretched length (m) L = 0.179 # rigid pendulum rod length (m) w = 0.038 # rigid pendulum rod width (m) m = 0.110 # rigid pendulum mass (kg) g = 9.81 # gravitational acceleration (m/s^2) k = 40.0 # elastic constant (kg/s^2) ''' Calculate the absolute difference between two angles th1 and th2 on a circle. Assumed that the absolute difference between the two angles is within range (0,pi). ''' def calc_diff(th1, th2): diff = np.abs(th2 - th1) diff = np.minimum(diff, 2*np.pi-diff) return diff ''' Calculate the average of two angles th1 and th2 on a circle. Assumed that the absolute difference between the two angles is within range (0,pi). ''' def calc_avrg(th1, th2): avrg = (th1 + th2) / 2 diff = np.abs(th2 - th1) if diff > np.pi: avrg -= np.pi if avrg < 0: avrg += 2*np.pi return avrg ''' Normalize the angle to (-pi, pi). ''' def normalize_angle(theta): return np.arctan2(np.sin(theta), np.cos(theta)) ''' Calculate velocities from a sequence of data using numerical differentiation. method='fd': finite difference; method='spline': cubic spline fitting. 
''' def calc_velocity(x, method='spline', periodic=True): len_seq = x.shape[0] # isolated data if len_seq == 1: return np.nan # preprocessing: periodic extension of angles if periodic: for i in range(1, len_seq): if x[i] - x[i-1] > np.pi: x[i:] -= 2*np.pi elif x[i] - x[i-1] < -np.pi: x[i:] += 2*np.pi vel_x = np.zeros(len_seq) # finite difference if method == 'fd': for i in range(1, len_seq): vel_x[i] = (x[i] - x[i-1]) * fps vel_x[0] = (x[1] - x[0]) * fps # cubic spline fitting elif method == 'spline': t = np.arange(len_seq) / fps cs = CubicSpline(t, x) vel_x = cs(t, 1) # use finite difference at boundary points to improve accuracy vel_x[0] = (x[1] - x[0]) * fps vel_x[-1] = (x[-1] - x[-2]) * fps else: assert False, 'Unrecognizable differentiation method!' return vel_x ''' Calculate energies ''' def calc_energy(th1, th2, z, vel_th1, vel_th2, vel_z): # center of mass in x-y coordinates x = (L_0 + z) * np.sin(th1) + (L / 2) * np.sin(th2) y = -(L_0 + z) * np.cos(th1) - (L / 2) * np.cos(th2) # velocities in x-y coordinates vel_x = vel_th1 * (L_0 + z) * np.cos(th1) + vel_th2 * (L / 2) * np.cos(th2) + vel_z * np.sin(th1) vel_y = vel_th1 * (L_0 + z) * np.sin(th1) + vel_th2 * (L / 2) * np.sin(th2) - vel_z * np.cos(th1) # moment of inertia I = (1. / 12.) * m * (w ** 2 + L ** 2) # potential energy (gravitational + elastic) V = m * g * y + 0.5 * k * z**2 # kinetic energy (translation + rotation) T = 0.5 * m * (vel_x ** 2 + vel_y ** 2) + 0.5 * I * vel_th2 ** 2 # total energy E = T + V return T, V, E ================================================ FILE: analysis/eval_phys_long_term_pred.py ================================================ """ Evaluate long-term prediction accuracy in physical variables. """ import os import sys import numpy as np from scipy import stats import cv2 import json import yaml import pprint from tqdm import tqdm from munch import munchify ''' Calculate the absolute difference between two angles th1 and th2 on a circle. 
Assumed that the absolute difference between the two angles is within range (0,pi). ''' def calc_diff(th1, th2): diff = np.abs(th2 - th1) diff = np.minimum(diff, 2*np.pi-diff) return diff ''' Calculate the pixel mean square error between two images. ''' def calc_pixel_MSE(img1, img2): return np.mean((img1/255.0 - img2/255.0) ** 2) def load_config(filepath): with open(filepath, 'r') as stream: try: trainer_params = yaml.safe_load(stream) return trainer_params except yaml.YAMLError as exc: print(exc) def eval_pred_physics(data_filepath, pred_save_path, vid_ids, phys_vars_list, eval_physics): phys = {p_var:[] for p_var in phys_vars_list} for idx in tqdm(vid_ids): data_vid_filepath = os.path.join(data_filepath, str(idx)) pred_vid_filepath = os.path.join(pred_save_path, str(idx)) num_frames = len(os.listdir(data_vid_filepath)) num_pred_frames = len(os.listdir(pred_vid_filepath)) frames = [] for p in range(num_frames): if (num_pred_frames == num_frames - 2) and (p == 0 or p == 1): # read the initial two frames from ground truth data frame_p = cv2.imread(os.path.join(data_vid_filepath, str(p)+'.png')) else: frame_p = cv2.imread(os.path.join(pred_vid_filepath, str(p)+'.png')) frames.append(frame_p) phys_tmp = eval_physics(frames) for p_var in phys_vars_list: phys[p_var].append(phys_tmp[p_var]) for p_var in phys_vars_list: phys[p_var] = np.array(phys[p_var]) return phys def load_data_physics(data_filepath, vid_ids, phys_vars_list): phys = np.load(os.path.join(data_filepath, 'phys_vars.npy'), allow_pickle=True).item() for p_var in phys_vars_list: phys[p_var] = phys[p_var][vid_ids] return phys def eval_physics_error(phys_pred, phys_data, phys_vars_list): phys_vars_list_2 = [p_var for p_var in phys_vars_list if p_var!='reject'] phys_error = {} phys_error['reject'] = phys_pred['reject'].copy() if 'reject' in phys_data.keys(): phys_error['reject_data'] = phys_data['reject'].copy() else: phys_error['reject_data'] = np.zeros(phys_pred['reject'].shape) for p_var in 
phys_vars_list_2: if p_var in ['theta', 'theta_1', 'theta_2']: phys_error[p_var] = calc_diff(phys_pred[p_var], phys_data[p_var]) else: phys_error[p_var] = np.abs(phys_pred[p_var] - phys_data[p_var]) return phys_error def eval_pixel_error(data_filepath, pred_save_path, vid_ids): pixel_error = [] for idx in tqdm(vid_ids): data_vid_filepath = os.path.join(data_filepath, str(idx)) pred_vid_filepath = os.path.join(pred_save_path, str(idx)) num_frames = len(os.listdir(data_vid_filepath)) pixel_error_idx = [] for p in range(num_frames): if p == 0 or p == 1: pixel_error_idx.append(0) else: data = cv2.imread(os.path.join(data_vid_filepath, str(p)+'.png')) pred = cv2.imread(os.path.join(pred_vid_filepath, str(p)+'.png')) pixel_error_idx.append(calc_pixel_MSE(data, pred)) pixel_error.append(pixel_error_idx) return np.array(pixel_error) def main(): config_filepath = str(sys.argv[1]) pred_save_path = str(sys.argv[2]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) data_filepath = os.path.join(cfg.data_filepath, cfg.dataset) with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file: seq_dict = json.load(file) vid_ids = seq_dict['test'] if cfg.dataset == 'single_pendulum': from eval_phys_single_pendulum import phys_vars_list, eval_physics elif cfg.dataset == 'double_pendulum': from eval_phys_double_pendulum import phys_vars_list, eval_physics elif cfg.dataset == 'elastic_pendulum': from eval_phys_elastic_pendulum import phys_vars_list, eval_physics else: assert False, 'Unknown system...' 
phys_pred = eval_pred_physics(data_filepath, pred_save_path, vid_ids, phys_vars_list, eval_physics) phys_data = load_data_physics(data_filepath, vid_ids, phys_vars_list) phys_error = eval_physics_error(phys_pred, phys_data, phys_vars_list) pixel_error = eval_pixel_error(data_filepath, pred_save_path, vid_ids) np.save(os.path.join(pred_save_path, 'phys_vars.npy'), phys_pred) np.save(os.path.join(pred_save_path, 'phys_error.npy'), phys_error) np.save(os.path.join(pred_save_path, 'pixel_error.npy'), pixel_error) if __name__ == '__main__': main() ================================================ FILE: analysis/eval_phys_single_pendulum/__init__.py ================================================ from .angle_estimator import obtain_angle from .physics_estimator import * import numpy as np phys_vars_list = {'reject', 'theta', 'vel_theta', 'kinetic energy', 'potential energy', 'total energy'} def eval_physics(frames): num_frames = len(frames) reject = np.zeros(num_frames, dtype=bool) theta = np.zeros(num_frames) # estimate angles for p in range(num_frames): reject_p, theta_p, _ = obtain_angle(frames[p]) if reject_p: reject[p] = True theta[p] = np.nan else: reject[p] = False theta[p] = theta_p # calculate velocities vel_theta = np.zeros(num_frames) sub_ids = np.ma.clump_unmasked(np.ma.masked_array(theta, reject)) for ids in sub_ids: vel_theta[ids] = calc_velocity(theta[ids].copy()) vel_theta[reject] = np.nan # calculate energies kinetic_energy, potential_energy, total_energy = calc_energy(theta, vel_theta) # save results phys = dict.fromkeys(phys_vars_list) phys['reject'] = reject phys['theta'] = theta phys['vel_theta'] = vel_theta phys['kinetic energy'] = kinetic_energy phys['potential energy'] = potential_energy phys['total energy'] = total_energy return phys ================================================ FILE: analysis/eval_phys_single_pendulum/angle_estimator.py ================================================ ''' This script provides utility functions estimating the 
angle of a single pendulum from image. Step 1. Extract the pendulum from the image. Step 2. Do a rectangle fitting of the pendulum. Step 3. Estimate the angle of the pendulum. Certain images will be rejected if the pendulum does not exist or has a wrong shape. ''' import cv2 import os import numpy as np ''' Extract the pendulum from the image. Args: img: single pendulum image in BGR format Returns: seg: segmentation of the pendulum ''' def seg_from_img(img): # pixel thresholds (in HSV) v_min = (60, 0, 0) v_max = (255, 255, 255) # extract the pendulum img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) seg = cv2.inRange(img_hsv, v_min, v_max) return seg ''' Fit pendulum in a rectangle. Args: seg: segmentation of the pendulum Returns: rej: (True/False) if the image is rejected rect: (Box2D structure) the fitted rectangle ''' def fit_pendulum(seg): # find all contours contours, _ = cv2.findContours(seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) # reject if no contours found if len(contours) == 0: return True, None # find the contour with the maximum area cnt = max(contours, key=cv2.contourArea) area = cv2.contourArea(cnt) # reject if the contour is too small if area < 400: return True, None # rectangle fitting rect = cv2.minAreaRect(cnt) _, (width, height), _ = rect # reject if the rectangle is too close to a square if abs(height-width) < 35: return True, None # reject if the rectangle does not properly fit the contour if height*width > 1.5 * area: return True, None return False, rect ''' Estimate the angle of pendulum from its fitted rectangle. 
Args: rect: (Box2D structure) the fitted rectangle Returns: angle: the estimated angle in radians in range (0, 2*pi) box: box points of the rectangle arrow: the arrow pointing along the angle ''' def estimate_angle(rect): # box points box = cv2.boxPoints(rect) # center, width and height (cx, cy), (width, height), _ = rect # specify the direction vector of the pendulum if width < height: v_d = box[0] - box[1] else: v_d = box[0] - box[3] # reference vector v_ref = np.array([cx-64, cy-64]) # choose the one from v_d and -v_d whose angle with v_ref # is less than pi if np.dot(v_d, v_ref) < 0: v_d = -v_d # counterclockwise angle from (0,-1) to v_d angle = np.arctan2(-v_d[1], v_d[0]) + np.pi/2 if angle < 0: angle += 2*np.pi arrow = ((int(cx), int(cy)), (int(cx+0.7*v_d[0]), int(cy+0.7*v_d[1]))) return angle, box, arrow ''' Obtain the angle of single pendulum from image Args: img: single pendulum image in BGR format Returns: rej: (True/False) if the image is rejected angle: the estimated angle in radians in range (0, 2*pi) img_marked: image marked with the fitted rectangle, the direction vector and the estimated angle (BGR format) ''' def obtain_angle(img): img_marked = img.copy() seg = seg_from_img(img) rej, rect = fit_pendulum(seg) if not rej: angle, box, arrow = estimate_angle(rect) # mark the fitted rectangle cv2.drawContours(img_marked, [np.int0(box)], 0, (0,0,255), 2) # mark the direction vector cv2.arrowedLine(img_marked, arrow[0], arrow[1], (0, 0, 255), 1, tipLength=0.25) # mark the estimated angle in degrees cv2.putText(img_marked, str(round(angle*180/np.pi)), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) else: # mark the rejection cv2.putText(img_marked, 'Reject', (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) angle = np.nan return rej, angle, img_marked ================================================ FILE: analysis/eval_phys_single_pendulum/physics_estimator.py ================================================ ''' This script provides utility 
functions estimating angular velocities and other physical quantities from angles of single pendulum. ''' import numpy as np from scipy.interpolate import CubicSpline # physical parameters fps = 60 # frames per second l = 0.5 # pendulum rod length (m) m = 1.0 # bob mass (kg) g = 9.81 # gravitational acceleration (m/s^2) ''' Calculate the absolute difference between two angles th1 and th2 on a circle. Assumed that the absolute difference between the two angles is within range (0,pi). ''' def calc_diff(th1, th2): diff = np.abs(th2 - th1) diff = np.minimum(diff, 2*np.pi-diff) return diff ''' Calculate the average of two angles th1 and th2 on a circle. Assumed that the absolute difference between the two angles is within range (0,pi). ''' def calc_avrg(th1, th2): avrg = (th1 + th2) / 2 diff = np.abs(th2 - th1) if diff > np.pi: avrg -= np.pi if avrg < 0: avrg += 2*np.pi return avrg ''' Normalize the angle to (-pi, pi). ''' def normalize_angle(theta): return np.arctan2(np.sin(theta), np.cos(theta)) ''' Calculate angular velocities from a sequence of angles using numerical differentiation. method='fd': finite difference; method='spline': cubic spline fitting. ''' def calc_velocity(th, method='spline'): len_seq = th.shape[0] # isolated data if len_seq == 1: return np.nan # preprocessing: periodic extension of angles for i in range(1, len_seq): if th[i] - th[i-1] > np.pi: th[i:] -= 2*np.pi elif th[i] - th[i-1] < -np.pi: th[i:] += 2*np.pi vel_th = np.zeros(len_seq) # finite difference if method == 'fd': for i in range(1, len_seq): vel_th[i] = (th[i] - th[i-1]) * fps vel_th[0] = (th[1] - th[0]) * fps # cubic spline fitting elif method == 'spline': t = np.arange(len_seq) / fps cs = CubicSpline(t, th) vel_th = cs(t, 1) # use finite difference at boundary points to improve accuracy vel_th[0] = (th[1] - th[0]) * fps vel_th[-1] = (th[-1] - th[-2]) * fps else: assert False, 'Unrecognizable differentiation method!' 
return vel_th ''' Calculate energies from angles and angular velocities Moment of inertia of the pendulum: I=ml^2/3 ''' def calc_energy(th, vel_th): T = m * l**2 / 6 * vel_th**2 V = - m * g * l / 2 * np.cos(th) E = T + V return T, V, E ================================================ FILE: analysis/eval_regression.py ================================================ import os import sys import numpy as np from tqdm import tqdm import json import yaml import pprint from munch import munchify from latent_regression import pca, mlp_regress def load_config(filepath): with open(filepath, 'r') as stream: try: trainer_params = yaml.safe_load(stream) return trainer_params except yaml.YAMLError as exc: print(exc) # parse the video no. and frame no. from the data id (example : '7_0.png') def parse_data_id(data_id): lhs, rhs = data_id.split('_') vid_n = int(lhs) frm_n = int(rhs.split('.')[0]) return vid_n, frm_n def physical_variables_from_data_ids(phys_all, ids): phys_vars_list = [] for p_var in phys_all.keys(): if p_var == 'reject': continue for t in range(4): phys_vars_list.append(f'{p_var} (t={t})') num_data = ids.shape[0] phys = {p_var:np.zeros(num_data) for p_var in phys_vars_list} for n in range(num_data): vid_n, frm_n = parse_data_id(ids[n]) for p_var in phys_all.keys(): if p_var == 'reject': continue for t in range(4): phys[f'{p_var} (t={t})'][n] = phys_all[p_var][vid_n, frm_n+t] return phys def main(): config_filepath = str(sys.argv[1]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) if_refine = ('refine' in cfg.model_name) phys_all = np.load(os.path.join(cfg.data_filepath, cfg.dataset, 'phys_vars.npy'), allow_pickle=True).item() log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) # latent vectors and physical variables (train and test) ids_train = np.load(os.path.join(log_dir, 'variables_train', 'ids.npy')) ids_test = np.load(os.path.join(log_dir, 'variables', 'ids.npy')) if if_refine: latent_train = 
np.load(os.path.join(log_dir, 'variables_train', 'refine_latent.npy')) latent_test = np.load(os.path.join(log_dir, 'variables', 'refine_latent.npy')) else: latent_train = np.load(os.path.join(log_dir, 'variables_train', 'latent.npy')) latent_test = np.load(os.path.join(log_dir, 'variables', 'latent.npy')) phys_train = physical_variables_from_data_ids(phys_all, ids_train) phys_test = physical_variables_from_data_ids(phys_all, ids_test) phys_vars_list = phys_train.keys() # remove outliers is_outlier_train = np.full(latent_train.shape[0], False) for p_var in phys_vars_list: is_outlier_train = is_outlier_train | np.isnan(phys_train[p_var]) for p_var in phys_vars_list: phys_train[p_var] = phys_train[p_var][~is_outlier_train] latent_train = latent_train[~is_outlier_train] is_outlier_test = np.full(latent_test.shape[0], False) for p_var in phys_vars_list: is_outlier_test = is_outlier_test | np.isnan(phys_test[p_var]) for p_var in phys_vars_list: phys_test[p_var] = phys_test[p_var][~is_outlier_test] latent_test = latent_test[~is_outlier_test] print(f'Valid training samples: {latent_train.shape[0]}; test samples: {latent_test.shape[0]}.') # only use first few principal components (optional) if str(sys.argv[2]) == 'NA': num_components = None else: num_components = int(sys.argv[2]) latent_train, latent_test, pca_model = pca(latent_train, latent_test, num_components, cfg.seed) num_samples = 3 sample_size = int(0.3 * latent_train.shape[0]) train_error = {p_var:np.zeros(num_samples) for p_var in phys_vars_list} test_error = {p_var:np.zeros(num_samples) for p_var in phys_vars_list} reg_models = {p_var:np.empty(num_samples, dtype=object) for p_var in phys_vars_list} rng = np.random.default_rng(cfg.seed) for i in range(num_samples): print('random samples #'+str(i+1)) samp_ids = rng.choice(latent_train.shape[0], sample_size, replace=False) for p_var in phys_vars_list: print('doing '+p_var+' regression...') train_error[p_var][i], test_error[p_var][i], reg_models[p_var][i] = 
mlp_regress(latent_train[samp_ids], latent_test, phys_train[p_var][samp_ids], phys_test[p_var], cfg.seed) if num_components == None: np.save(os.path.join(log_dir, 'variables', 'regression_results.npy'), [train_error, test_error, reg_models]) else: np.save(os.path.join(log_dir, 'variables', 'regression_results_pca_'+str(num_components)+'.npy'), [train_error, test_error, pca_model, reg_models]) if __name__ =='__main__': main() ================================================ FILE: analysis/intrinsic_dimension_estimation/__init__.py ================================================ import numpy as np from .methods import * class ID_Estimator: def __init__(self, method='Levina_Bickel'): self.all_methods = ['Levina_Bickel', 'MiND_ML', 'MiND_KL', 'Hein', 'CD'] self.set_method(method) def set_method(self, method='Levina_Bickel'): if method not in self.all_methods: assert False, 'Unknown method!' else: self.method = method def fit(self, X, k_list=20, n_jobs=4): if self.method in ['Hein', 'CD']: dim_Hein, dim_CD = Hein_CD(X) return dim_Hein if self.method=='Hein' else dim_CD else: if np.isscalar(k_list): k_list = np.array([k_list]) else: k_list = np.array(k_list) kmax = np.max(k_list) + 2 dists, inds = kNN(X, kmax, n_jobs) dims = [] for k in k_list: if self.method == 'Levina_Bickel': dims.append(Levina_Bickel(X, dists, k)) elif self.method == 'MiND_ML': dims.append(MiND_ML(X, dists, k)) elif self.method == 'MiND_KL': dims.append(MiND_KL(X, dists, k)) else: pass if len(dims) == 1: return dims[0] else: return np.array(dims) def fit_all_methods(self, X, k_list=[20], n_jobs=4): k_list = np.array(k_list) kmax = np.max(k_list) + 2 dists, inds = kNN(X, kmax, n_jobs) dim_all_methods = {method:[] for method in self.all_methods} dim_all_methods['Hein'], dim_all_methods['CD'] = Hein_CD(X) for k in k_list: dim_all_methods['Levina_Bickel'].append(Levina_Bickel(X, dists, k)) dim_all_methods['MiND_ML'].append(MiND_ML(X, dists, k)) dim_all_methods['MiND_KL'].append(MiND_KL(X, dists, k)) for 
method in self.all_methods: dim_all_methods[method] = np.array(dim_all_methods[method]) return dim_all_methods ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/DANCo.m ================================================
function [d,kl] = DANCo(data,varargin)
% DANCO Estimating the intrinsic dimensionality of a dataset
%
% function [d,kl] = DANCo(data, ParamName,ParamValue, ...)
%
% This function estimates the intrinsic dimensionality of a given dataset
% by means of the DANCo algorithm described in:
%
%   "DANCo: Dimensionality from Angle and Norm Concentration"
%   C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli
%   arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1
%
% The reference statistics for every candidate dimensionality are computed
% here from freshly generated hyper-sphere samples (randsphere); DANCoFit is
% the faster variant that reads them from the fitting models in DANCo_fits.
%
% Parameters
% ----------
% IN:
%   data = The data matrix (DxN) with a sample per column.
%   Named parameters:
%     'k'
%       The number of neighbours to be used, must be >= 1. (def=10)
%     'minDim'
%       The minimal number of dimensions. (def=1)
%     'maxDim'
%       The maximal number of dimensions. (def=size(data,1))
%     'sigma'
%       Standard deviation of the gaussian smoothing filter executed on
%       the estimated divergences. (def=[]=no smoothing)
%     'fractal'
%       Is the fractal dimension required? (def=false)
%     'mlMethod'
%       Algorithm used to pre-estimate d:
%         MLE = Levina's algorithm.
%         MiND_MLI = ML on the first neighbour.
%         MiND_MLK = ML on the first neighbour and optimization.
%       (def='MiND_MLK')
%     'minCorrectionDim'
%       Minimal dimensionality estimated with the selected mlMethod such
%       that DANCO is adopted for correction, for lower dimensionalities
%       the mlMethod result is produced. (def=5)
%     'inds'
%       Precomputed knn indices as returned by the KNN function for k
%       set to k+2 with respect to the given k parameter. (also dists must
%       be present)
%     'dists'
%       Precomputed knn distances as returned by the KNN function for k
%       set to k+2 with respect to the given k parameter. (also inds must
%       be present)
%     'normalized'
%       Are the inds/dists already normalized? Normalized here means that
%       all the k+2 distances are normalized wrt the last one and the first
%       (zero) and the last (one) are removed. (def=true)
% OUT:
%   d = Estimated intrinsic dimensionality of the dataset: the mlMethod
%       pre-estimate for low dimensionalities (see 'minCorrectionDim'),
%       otherwise the dimensionality minimizing kl.
%   kl = The Kullback Leibler divergences curve; kl(i) refers to the
%        dimensionality minDim+i-1 (i.e. d=1..D with the defaults). It is
%        computed whenever it is requested or the correction is applied.
%
% Examples
% --------
%   X = randsphere(30,1000,1);
%   V = linSubspSpanOrthonormalize(randn(120,30));
%   pts = V*X;
%   [inds,dists] = KNN(pts,12,true);
%   [d,kl] = DANCo(pts,'inds',inds,'dists',dists,'minDim',20,'maxDim',40);

% Checking parameters:
if nargin<1; error('A dataset is required'); end

% Infos:
[D,N] = size(data);

% Default parameters:
params = struct('k',{10}, ...
    'minDim',{1}, 'maxDim',{D}, ...
    'fractal',{false}, ...
    'sigma',{[]},...
    'mlMethod',{'MiND_MLK'}, ...
    'normalized',{true}, ...
    'minCorrectionDim',{5});

% Parsing named parameters:
if nargin > 1
    try
        params = parseParamsNamed(varargin, false, params);
    catch e
        error('Pairs ''name,value'' are required as arguments after the dataset.');
    end
end

% Estimating the parameters for the real data:
[isLowDim,dHat,mu,tau] = DANCo_statistics(data,params.k,params,nargout > 1);

% Correction based on the KL-divergence (also run when kl is requested):
if not(isLowDim) || nargout > 1
    % Initializing the parameters:
    params2 = struct( ...
        'mlMethod',{params.mlMethod}, ...
        'minCorrectionDim',{params.minCorrectionDim});
    numDims = params.maxDim - params.minDim + 1;

    % Computing the reference parameters, one per candidate
    % dimensionality minDim..maxDim:
    dHat_ref = zeros(1,numDims);
    mu_ref = zeros(1,numDims);
    tau_ref = zeros(1,numDims);
    for i = 1:numDims
        % Generating the dataset:
        data2 = randsphere(params.minDim + i - 1, N);

        % Estimating the parameters:
        [~,dHat_ref(i),mu_ref(i),tau_ref(i)] = ...
            DANCo_statistics(data2,params.k,params2,true);
    end

    % Estimating the KL-divergence for all the cases d=minDim:maxDim:
    kl = DANCo_estimateKL(params.k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref);
end

% Smoothing if required:
if (not(isLowDim) || nargout > 1) && not(isempty(params.sigma))
    % Preparing the impulse response:
    rad = 4*ceil(params.sigma/2);
    h = gauss(-rad:rad,params.sigma,0,false);
    h = h/sum(h);

    % Filtering a padded copy, then cropping back to numDims entries.
    % NOTE(review): the left pad is mirrored but the right pad is kl itself
    % (periodic); confirm this asymmetry is intended.
    kl = conv([flipdim(kl,2),kl,kl],h,'same');
    kl = kl(numDims+1:end-numDims);
end

% The dimensionality estimation:
if isLowDim
    % Using the ML pre-estimate: dHat is already an absolute
    % dimensionality, so the minDim offset must not be applied to it.
    if params.fractal; d = dHat; else d = round(dHat); end
else
    % Selecting the minimum kl position:
    [~,d] = min(kl);

    % Refining for the fractal dimension:
    if params.fractal
        % Interpolating kl with a cubic spline (csapi):
        fn = csapi(1:numDims,kl);

        % Locating the minima:
        opts = optimset('Display','off','TolX',1e-3);
        d = fminsearch(@(x)(fnval(fn,x)),d,opts);
    end

    % Converting the kl index into a dimensionality (kl(1) <-> minDim):
    d = d + params.minDim - 1;
end

end
================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/DANCoFit.m ================================================ function [d,kl] = DANCoFit(data,varargin) % DANCOFIT Estmating the intrinsic dimensionality of a dataset % % function [d,kl] = DANCoFit(data, ParamName,ParamValue, ...) % % This funciton estimates the intrinsic dimensionality of a given dataset % by means of the fast DANCo algorithm described in: % % "DANCo: Dimensionality from Angle and Norm Concentration" % C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli % arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1 % % with the complexity improvement given by separating the training step % from the evaluation one (requires fitting models in the file DANCo_fits). % % Parameters % ---------- % IN: % data = The data matrix (DxN) with a sample per column. % Named parameters: % 'fractal' % Is the fractal dimension required?
(def=false) % 'mlMethod' % Algorithm used to pre-estimate d: % MLE = Levina's algorithm. % MiND_MLI = ML on the first neighbour. % MiND_MLK = ML on the first neighbour and optimization. % (def='MiND_MLK') % 'minCorrectionDim' % Minimal dimensionality estimated with the selected mlMethod such % that DANCO is adopted for correction, for lower dimensionalities % the mlMethod result is produced. (def=5) % 'inds' % Precomputed knn indices as returned by the KNN function for k % set to k+2 with respect to the given k parameter. (also dists must % be present) % 'dists' % Precomputed knn distances as returned by the KNN function for k % set to k+2 with respect to the given k parameter. (also inds must % be present) % 'normalized' % Are the inds/dists already normalized? Normalized here means that % all the k+2 distances are normalized wrt the last one and the first % (zero) and the last (one) are removed. (def=true) % 'modelfile' % The name of the model file to be used. (def='DANCo_fits') % OUT: % d = Estimated intrinsic dimensionality of the dataset. % kl = The Kullback Leibler divergences curve for d=1..D. % % Examples % -------- % X = randsphere(30,1000,1); % V = linSubspSpanOrthonormalize(randn(120,30)); % pts = V*X; % [inds,dists] = KNN(pts,12,true); % [d,kl] = DANCoFit(pts,'inds',inds,'dists',dists); % Checking parameters: if nargin<1; error('A dataset is required'); end % Infos: [D,N] = size(data); d = 1:D; % Default parameters: params = struct('fractal',{false}, ... 'mlMethod',{'MiND_MLK'}, ... 'normalized',{true}, ... 'minCorrectionDim',{5}, ... 
'modelfile',{'DANCo_fits'}); % Parsing named parameters: if nargin > 1 try params = parseParamsNamed(varargin, false, params); catch e error('Pairs ''name,value'' are required as arguments after the dataset.'); end end % Loading the model: load(params.modelfile); % Estimating the parameters: [isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,nargout > 1); % Correction based on the KL-divergence: if not(isLowDim) || nargout > 1 % Computing the reference parameters: dHat_ref = feval(fitDhat,d,N); mu_ref = feval(fitMu,d,N); tau_ref = feval(fitTau,d,N); % Estimating the KL-divergence for all the cases d=1:D: kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref); end % The dimensionality estimation: if isLowDim % Using the MLE estimation: if params.fractal; d = dHat; else d = round(dHat); end else % Selecting the minimum kl position: [~,d] = min(kl); % Refining for the fractal dimension: if params.fractal % Fitting with a cubic smoothing spline: fn = csapi(1:D,kl); % Locating the minima: opts = optimset('Display','off','TolX',1e-3); d = fminsearch(@(x)(fnval(fn,x)),d,opts); end end end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/DANCoTrain.m ================================================ function DANCoTrain(modelfile,varargin) % DANCoTrain Training the DANCoFit model file. % % function DANCoTrain(modelfile, ParamName,ParamValue, ...) % % This funciton executes experiments with points uniformly drawn from % hyper-spheres with unitary radius to produce the fitting functions used % by DANCoFit, if you use this please cite: % % "DANCo: Dimensionality from Angle and Norm Concentration" % C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli % arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1 % % Parameters % ---------- % IN: % modelfile = The name of the model file to be produced. % Named parameters: % 'k' % The number of neighbours to be used, must be >= 1. 
(def=10) % 'maxDim' % The maximal dimensionality to be used in the experiments, more % dims means more computation but also accurate fits. (def=100) % 'cardinalities' % The list of cardinalities to be used in the experiments, more % cardinalities means more computation but also accurate fits, the % suggested choice is to adopt an 1-2-5 series. % (def=[50,100,200,500,1000,2000,5000]) % 'iterations' % The number of iterations for experiment (results averaged), more % iterations means more computation but also accurate fits. (def=100) % 'mlMethod' % Algorithm used to pre-estimate d: % MLE = Levina's algorithm. % MiND_MLI = ML on the first neighbour. % MiND_MLK = ML on the first neighbour and optimization. % (def='MiND_MLK') % % Example % ------- % DANCoTrain('sampleModel','maxDim',20,'cardinalities',[200,500,1000],'iterations',1); % X = randsphere(10,500); % V = linSubspSpanOrthonormalize(randn(20,10)); % pts = V*X; % DANCoFit(pts,'modelfile','sampleModel') % Checking parameters: if nargin<1; error('A model file name is required.'); end % Default parameters: params = struct('k',{10}, ... 'maxDim',{100}, ... 'cardinalities',{[50,100,200,500,1000,2000,5000]}, ... 'iterations',{100}, ... 'mlMethod',{'MiND_MLK'}); % Parsing named parameters: if nargin > 1 try params = parseParamsNamed(varargin, false, params); catch e error('Pairs ''name,value'' are required as arguments after the dataset.'); end end % Parameters to be passed to the statistics computation function: params2 = struct( ... 'mlMethod',{params.mlMethod}, ... 
'minCorrectionDim',{1}); k = params.k; % Computing the size of the statistic matrices: N = numel(params.cardinalities); D = params.maxDim; % Initialization: dHat = zeros(D-1,N); mu = zeros(D-1,N); tau = zeros(D-1,N); for d = 2:D for n = 1:N % Initializing the data of this iteration: dHat_ref = zeros(1,params.iterations); mu_ref = zeros(1,params.iterations); tau_ref = zeros(1,params.iterations); % Executing the experiments: for i = 1:params.iterations % Generating the dataset: data = randsphere(d, params.cardinalities(n)); % Sayng to the user what I'm doing: fprintf('(d = %d \tn = %d \ti = %d)...', ... d,params.cardinalities(n),i); % Estimating the parameters: [~,dHat_ref(i),mu_ref(i),tau_ref(i)] = ... DANCo_statistics(data,k,params2,true); % Sayng to the user what I'm doing: fprintf(' \tDONE!\n'); end % Storing the results: dHat(d-1,n) = mean(dHat_ref); mu(d-1,n) = mean(mu_ref); tau(d-1,n) = mean(tau_ref); end end % Sayng to the user what I'm doing: fprintf('Fitting & saving...'); % Initializing: inVars = {2:D,params.cardinalities}; % Fitting the funcitons: fitDhatFun = fnxtr(csaps(inVars,dHat),2); fitMuFun = fnxtr(csaps(inVars,mu),2); fitTauFun = fnxtr(csaps(inVars,tau),2); % Producing function handles: fitDhat = @(D,N)(fnval(fitDhatFun,{D,N})'); fitMu = @(D,N)(fnval(fitMuFun,{D,N})'); fitTau = @(D,N)(fnval(fitTauFun,{D,N})'); % Writing the model: save(modelfile,'fitMu','fitTau','fitDhat','k'); % Sayng to the user what I'm doing: fprintf(' \tDONE!\n'); end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/KNN.m ================================================ function [inds,dists] = KNN(X,k,normalized) %KNN A multi-version knnsearch % % function [inds,dists] = KNN(X,k) % % This function allows to identify the indexes (and optionally the % distances) of the first k neighbours of each given columnwise point. % % IN: % X = The dataset, each column represents a point. % k = The number of neighbours of interest. 
(def=10) % normalized = Must the k neighbours be normalized? % (triggers a k+2 search, def=false) % OUT: % inds = Each row i-th contains the neighbour indexes of the i-th point. % dists = Each row i-th contains the neighbour distances of the i-th point. % Checking parameters: if nargin<1; error('Dataset required.'); end if nargin<2; k=10; end if nargin<3; normalized=false; end % Managing th normalized: if normalized; k=k+2; end % Checking the version: if datenum(version('-date')) < 734729 % Computing all the distances: dists = L2(X, X); % Sorting: [dists, inds] = sort(dists, 2); % Choosing the first k elements: dists = dists(:,1:k); inds = inds(:,1:k); else % The Matlab implementation: X = X'; [inds, dists] = knnsearch(X,X,'K',k); end % Managing th normalized: if normalized % Cropping both inds and dists: inds = inds(:,2:end-1); dists = dists(:,2:end); % Normalizing and removing the last element: dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]); end end % ------------------------ LOCAL FUNCTIONS ------------------------ % L2 distance (by Roland Bunschoten): function d = L2(a,b) d = real(sqrt(bsxfun(@plus, sum(a .* a)', bsxfun(@minus, sum(b .* b), 2 * a' * b)))); end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/MLE.m ================================================ function [d,ds] = MLE(data,varargin) %MLE Intrinsic dimensionality using the Levina's MLE technique % % function [d,ds] = MLE(data, ParamName,ParamValue, ...) % % This funciton estimates the intrinsic dimensionality of a dataset using % the Levina's Maximum Likelihood Estimator (MLE). % % Parameters % ---------- % IN: % data = Matrix DxN containing the dataset as columnwise points. % Named parameters: % 'k' % The number of neighbours to be used, must be >= 1. 
(def=10) % 'dists' % Precomputed knn distanced as returned by the KNN function for k % set to k+2 with respect to the given k parameter (k is set to % "size(dists,2)-2" if knn is present). % 'normalized' % Are the dists already normalized? Normalized here means that all % the k+2 distances are normalized wrt the last one and the first % (zero) and the last (one) are removed. (def=true) % OUT: % d = Estimated intrinsic dimensionality of the dataset. % ds = Local estimations. % % Examples % -------- % X = randsphere(30,1000,1); % V = linSubspSpanOrthonormalize(randn(120,30)); % pts = V*X; % [inds,dists] = KNN(pts,12,true); % d = MLE(pts,'dists',dists); % Checking parameters: if nargin<1; error('A dataset is required'); end % Infos: N = size(data,2); % Default parameters: params = struct('k',{10},'normalized',{true}); % Parsing named parameters: if nargin > 1 try params = parseParamsNamed(varargin, false, params); catch e error('Pairs ''name,value'' are required as arguments after the dataset.'); end end % Is the knn already there? 
mustDoKnn = true; if isfield(params,'dists') % Extracting and checking: dists = params.dists; if size(dists,1) == N && size(dists,2) > 2 % KNN already present: params.k = size(dists,2)-2; mustDoKnn = false; end end % Finding the KNNs: if mustDoKnn if not(isfield(params,'k')) error('One parameter of ''dists'' or ''k'' is required.'); end [~,dists] = KNN(data,params.k,true); end % Normalizing the distances: if not(mustDoKnn || params.normalized) % Cropping the dists: dists = dists(:,2:end); % Normalizing and removing the last element: dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]); end % Local estimation: ds = -1./mean(log(dists),2); % Harmonic mean: d = 1./mean(1./ds); end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/MiND_KL.m ================================================ function [d,kl] = MiND_KL(data,varargin) %MiND_KL Intrinsic dimensionality using the MiND_KL technique % % function [d,kl] = MiND_KL(data, ParamName,ParamValue, ...) % % This funciton estimates the intrinsic dimensionality of a dataset using % the Minimum Neighbor Distance estimators described in: % % "Minimum Neighbor Distance Estimators of Intrinsic Dimension", % A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli, % Published to the European Conference of Machine Learning (ECML 2011). % % Parameters % ---------- % IN: % data = Matrix DxN containing the dataset as columnwise points. % Named parameters: % 'k' % The number of neighbours to be used, must be >= 1. (def=10) % 'minDim' % The minimal number of dimensions. (def=1) % 'maxDim' % The maximal number of dimensions. (def=size(data,1)) % 'sigma' % Standard deviation of the gaussian smoothing filter executed on % the estimated divergences. (def=[]=no smoothing) % 'dists' % Precomputed knn distanced as returned by the KNN function for k % set to k+2 with respect to the given k parameter (k is set to % "size(dists,2)-2" if knn is present). 
% 'normalized' % Are the dists already normalized? Normalized here means that all % the k+2 distances are normalized wrt the last one and the first % (zero) and the last (one) are removed. (def=true) % OUT: % d = Estimated intrinsic dimensionality of the dataset. % kl = The Kullback Leibler divergences curve for d=1..D. % % Examples % -------- % X = randsphere(30,1000,1); % V = linSubspSpanOrthonormalize(randn(120,30)); % pts = V*X; % [inds,dists] = KNN(pts,12,true); % [d,kl] = MiND_KL(pts,'dists',dists); % Checking parameters: if nargin<1; error('A dataset is required'); end % Infos: [D,N] = size(data); % Default parameters: params = struct('k',{10}, ... 'normalized',{true}, ... 'minDim',{1}, 'maxDim',{D}, ... 'sigma',[]); % Parsing named parameters: if nargin > 1 try params = parseParamsNamed(varargin, false, params); catch e error('Pairs ''name,value'' are required as arguments after the dataset.'); end end % Is the knn already there? mustDoKnn = true; if isfield(params,'dists') % Extracting and checking: dists = params.dists; if size(dists,1) == N && size(dists,2) > 2 % KNN already present: params.k = size(dists,2)-2; mustDoKnn = false; end end % Finding the KNNs: if mustDoKnn if not(isfield(params,'k')) error('One parameter of ''dists'' or ''k'' is required.'); end [~,dists] = KNN(data,params.k+2); end % Normalizing the distances: if mustDoKnn || not(params.normalized) % Cropping the dists: dists = dists(:,2:end); % Normalizing and removing the last element: dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]); end % Only the first normalized neighbour is required: k = size(dists,2); dists = dists(:,1)'; % Executing the experiments with unit uniformly-sampled hyper-spheres: numDims = params.maxDim - params.minDim + 1; kl = zeros(1,numDims); for i=1:numDims % Generating the data: data2 = randsphere(params.minDim + i - 1, N); % Computing the normalized knn distances: [~,distsExp] = KNN(data2,k,true); distsExp = distsExp(:,1)'; % Estimating the 
Kullback-Leibler divergence: kl(i) = wangKL1d(distsExp,dists); end % Smoothing if required: if not(isempty(params.sigma)) % Preparing the impulse response: rad = 4*ceil(params.sigma/2); h = gauss(-rad:rad,params.sigma,0,false); h = h/sum(h); % Filtering: kl = conv([flipdim(kl,2),kl,kl],h,'same'); kl = kl(numDims+1:end-numDims); end % Estimating the id: [~,d] = min(kl); d = d + params.minDim - 1; end % ------------------------ LOCAL FUNCTIONS ------------------------ % Wang's Kullback-Leibler divergence estimator in one dimension: function div = wangKL1d(data1,data2) % Getting data: n = numel(data1); m = numel(data2); data1 = sort(data1); % Generating the nearest neighbors for data1: nnData1 = [-inf,data1,inf]; nnData1 = abs(nnData1(1:end-1)-nnData1(2:end)); nnData1 = min([nnData1(1:end-1);nnData1(2:end)],[],1); % Generating the nearest neighbors for data2: nnData2 = min(abs(repmat(data1,[m,1]) - repmat(data2',[1,n])),[],1); % Computing the kl div: div = abs(sum(log(nnData2./(nnData1+eps)))/n + log(m/(n-1))); end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/MiND_ML.m ================================================ function d = MiND_ML(data,varargin) %MiND_ML Intrinsic dimensionality using the MiND_ML techniques % % function d = MiND_ML(data, ParamName,ParamValue, ...) % % This funciton estimates the intrinsic dimensionality of a dataset using % the Minimum Neighbor Distance estimators described in: % % "Minimum Neighbor Distance Estimators of Intrinsic Dimension", % A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli, % Published to the European Conference of Machine Learning (ECML 2011). % % Parameters % ---------- % IN: % data = Matrix DxN containing the dataset as columnwise points. % Named parameters: % 'k' % The number of neighbours to be used, must be >= 1. (def=10) % 'optimize' % Must an optimization step be executed to estimate a fractal id, % that is MiND_MLK instead of MiND_MLI? 
(def=false) % 'dists' % Precomputed knn distanced as returned by the KNN function for k % set to k+2 with respect to the given k parameter (k is set to % "size(dists,2)-2" if knn is present). % 'normalized' % Are the dists already normalized? Normalized here means that all % the k+2 distances are normalized wrt the last one and the first % (zero) and the last (one) are removed. (def=true) % OUT: % d = Estimated intrinsic dimensionality of the dataset. % % Examples % -------- % X = randsphere(30,1000,1); % V = linSubspSpanOrthonormalize(randn(120,30)); % pts = V*X; % [inds,dists] = KNN(pts,12,true); % d = MiND_ML(pts,'dists',dists); % Checking parameters: if nargin<1; error('A dataset is required'); end % Infos: [D,N] = size(data); % Default parameters: params = struct('k',{10},'optimize',{false},'normalized',{true}); % Parsing named parameters: if nargin > 1 try params = parseParamsNamed(varargin, false, params); catch e error('Pairs ''name,value'' are required as arguments after the dataset.'); end end % Is the knn already there? 
mustDoKnn = true; if isfield(params,'dists') % Extracting and checking: dists = params.dists; if size(dists,1) == N && size(dists,2) > 2 % KNN already present: params.k = size(dists,2)-2; mustDoKnn = false; end end % Finding the KNNs: if mustDoKnn if not(isfield(params,'k')) error('One parameter of ''dists'' or ''k'' is required.'); end [~,dists] = KNN(data,params.k,true); end % Normalizing the distances: if not(mustDoKnn || params.normalized) % Cropping the dists: dists = dists(:,2:end); % Normalizing and removing the last element: dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]); end % Infos: [n,k] = size(dists); % The first normalized distance: r = dists(:,1); % The log likelihood derivative squared: fun = @(d)((n/d)+sum(log(r)-((k-1).*(log(r).*r.^d))./(1-r.^d)))^2; % Evaluating it: vals = zeros(1,D); for d=1:D; vals(d)=fun(d); end % MiND_MLI result: [~,d] = min(vals); % The optimization: if params.optimize opts = optimset('Display','off','TolX',1e-3); d = fminsearch(fun,d,opts); end end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/demo_idEstimation.m ================================================ %% Intrinsic Dimensionality (id) estimation % % In the past decade the development of automatic techniques to estimate % the intrinsic dimensionality of a given dataset has gained considerable % attention due to its relevance in several application elds. % % In this small toolbox some of the state-of-art techniques are implemented % as functions that use a "standard" interface to use them, thus allowing % client code to be decoupled from the knowledge of the applied technique. % % For more details see/cite: % % "Minimum Neighbor Distance Estimators of Intrinsic Dimension", % A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli, % Published to the European Conference of Machine Learning (ECML 2011). % % "DANCo: Dimensionality from Angle and Norm Concentration", % C. Ceruti, S. 
Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli % arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1 %% Usage of this id-estimation toolbox % % As an example here we generate a simple dataset and test all the % implemented techniques on it to compare the results, the dataset is an % hypersphere in R^{10} linearly embedded in R^{20}. % Generating the dataset: d = 10; D = 20; X = randsphere(d,500); V = linSubspSpanOrthonormalize(randn(D,d)); data = V*X; % A first base estimation: fprintf('MLE => %2.2f\n', MLE(data)); fprintf('DANCoFit => %2.2f\n', DANCoFit(data)); %% % The techniques: estimators = {'MLE','MiND_ML','MiND_KL','DANCo','DANCoFit'}; % Initializing the results: idEst = zeros(1,numel(estimators)); spentTime = zeros(1,numel(estimators)); % Running the experiments: for i = 1:numel(estimators) % Obtaining the funciton: estimator = str2func(estimators{i}); % Estimating: tic; idEst(i) = estimator(data); spentTime(i) = toc; end % Plotting the results: figure; bar([idEst(:),spentTime(:)]); hold on; plot([0,numel(estimators)+1],[d,d],'b:'); set(gca,'xticklabel',estimators); legend({'Estimated id','Spent time'},'Location','NorthWest'); title(sprintf('Estimating the id of a %dd hypersphere linearly embedded in R^{%d}',d,D)); %% Estimators based on the Kullback-Leibler divergence % % Some of the presented algorithms estimate the Kullback-Leibler divergence % between the distribution of some statistics computed on the real dataset % and those computed on synthetically-cenerated data (hyperspheres) or by % means of fitting functions (trained using the statistics produced by % synthetic data). Here the results obtained by some of them. 
% Initialization: estimators = {'MiND_KL','DANCo','DANCoFit'}; % Initializing the results: idEst = zeros(1,numel(estimators)); spentTime = zeros(1,numel(estimators)); kls = zeros(numel(estimators),D); % Running the experiments: for i = 1:numel(estimators) % Obtaining the funciton: estimator = str2func(estimators{i}); % Estimating: tic; [idEst(i),kls(i,:)] = estimator(data); spentTime(i) = toc; end % Notice that DANCo produces unreliable results for d=1: Ds = 2:D; kls2 = kls(:,Ds); % Plotting the results: figure; plot(Ds,kls2'); hold on; plot([idEst;idEst],[zeros(size(idEst));max(kls2(:))*ones(size(idEst))],':'); legend(estimators,'Location','NorthEast'); title('KL-based id estimatros: the estimated KLs'); ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/gauss.m ================================================ function func = gauss(x, sigma, mu, norm) % GAUSS Returns the monodimensional gaussian function % % Return the evaluation of the gaussian function on an array of points % % Params: % % x: The array of x points (def=[-10:10]) % sigma: The standard deviation (def=2.5) % mu: The mean of th e gaussian (def=0) % norm: Must be normalized? (def=true) % Check x if nargin<1 x = -10:10; end % Check sigma if nargin<2 sigma = 2.5; end % Check mean if nargin<3 mu = 0; end % Check norm if nargin<4 norm = true; end % Computing the gaussian funcion if norm func = exp(-(x-mu).^2/(2*sigma^2))/(sigma*sqrt(2*pi)); else func = exp(-(x-mu).^2/(2*sigma^2)); end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/html/demo_idEstimation.html ================================================ Intrinsic Dimensionality (id) estimation

Intrinsic Dimensionality (id) estimation

In the past decade, the development of automatic techniques to estimate the intrinsic dimensionality of a given dataset has gained considerable attention due to its relevance in several application fields.

In this small toolbox, some of the state-of-the-art techniques are implemented as functions sharing a "standard" interface, thus allowing client code to be decoupled from the knowledge of the applied technique.

For more details see/cite:

"Minimum Neighbor Distance Estimators of Intrinsic Dimension",
A.Rozza, G.Lombardi, C.Ceruti, E.Casiraghi, P.Campadelli,
Published at the European Conference on Machine Learning (ECML 2011).
"DANCo: Dimensionality from Angle and Norm Concentration",
C. Ceruti, S. Bassis, A. Rozza, G. Lombardi, E. Casiraghi, P. Campadelli
arXiv:1206.3881v1, http://arxiv.org/abs/1206.3881v1

Contents

Usage of this id-estimation toolbox

As an example, here we generate a simple dataset and test all the implemented techniques on it to compare the results; the dataset is a hypersphere in R^{10} linearly embedded in R^{20}.

% Generating the dataset: a d-dimensional hypersphere sample (500 points)
% mapped into R^D through an orthonormalized basis V of a random subspace
d = 10; D = 20;
X = randsphere(d,500);
V = linSubspSpanOrthonormalize(randn(D,d));
data = V*X;

% A first base estimation:
fprintf('MLE      => %2.2f\n', MLE(data));
fprintf('DANCoFit => %2.2f\n', DANCoFit(data));
MLE      => 8.07
DANCoFit => 10.00
% The techniques (each one is called by name through str2func below):
estimators = {'MLE','MiND_ML','MiND_KL','DANCo','DANCoFit'};

% Initializing the results:
idEst = zeros(1,numel(estimators));
spentTime = zeros(1,numel(estimators));

% Running the experiments:
for i = 1:numel(estimators)
    % Obtaining the function handle:
    estimator = str2func(estimators{i});

    % Estimating (tic/toc measure the elapsed time in seconds):
    tic;
    idEst(i) = estimator(data);
    spentTime(i) = toc;
end

% Plotting the results: one bar pair per estimator (estimated id and time
% spent), plus a dotted horizontal line at the true id d
figure; bar([idEst(:),spentTime(:)]); hold on;
plot([0,numel(estimators)+1],[d,d],'b:');
set(gca,'xticklabel',estimators);
legend({'Estimated id','Spent time'},'Location','NorthWest');
title(sprintf('Estimating the id of a %dd hypersphere linearly embedded in R^{%d}',d,D));

Estimators based on the Kullback-Leibler divergence

Some of the presented algorithms estimate the Kullback-Leibler divergence between the distribution of some statistics computed on the real dataset and those computed on synthetically generated data (hyperspheres) or obtained by means of fitting functions (trained on the statistics produced by synthetic data). Here are the results obtained by some of them.

% Initialization:
estimators = {'MiND_KL','DANCo','DANCoFit'};

% Initializing the results:
idEst = zeros(1,numel(estimators));
spentTime = zeros(1,numel(estimators));
kls = zeros(numel(estimators),D);

% Running the experiments:
for i = 1:numel(estimators)
    % Obtaining the function:
    estimator = str2func(estimators{i});

    % Estimating (the second output fills one row of kls, one KL per candidate d=1:D):
    tic;
    [idEst(i),kls(i,:)] = estimator(data);
    spentTime(i) = toc;
end

% Notice that DANCo produces unreliable results for d=1:
Ds = 2:D;
kls2 = kls(:,Ds);

% Plotting the results:
figure; plot(Ds,kls2'); hold on;
plot([idEst;idEst],[zeros(size(idEst));max(kls2(:))*ones(size(idEst))],':');
legend(estimators,'Location','NorthEast');
title('KL-based id estimatros: the estimated KLs');
================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/linSubspSpanOrthonormalize.m ================================================ function Vo = linSubspSpanOrthonormalize(V) % LINSUBSPSPANORTHONORMALIZE Orthonormalizes a linear subspace span matrix. % % function Vo = linSubspSpanOrthonormalize(V) % % This function takes a DxN matrix that represents a set of N linearly % independent vectors of a D-dimensional space, and that identifies an % N-dimensional linear subspace, and orthonormalizes them so that the % subspace spanned is the same but the vectors are versors orthogonal to % each other. The solutions to this problem are infinitely many, the one % given is the one that normalizes the first vector, orthonormalizes the % second with respect to the plane containing the first two vectors, % orthonormalizes the third with respect to the 3D space spanned by the % first three vectors and so on (Gram-Schmidt). Obviously must be 1<=N<=D. % % Parameters % ---------- % IN: % V = The DxN matrix containing the N vectors. % OUT: % Vo = The orthonormalized linear subspace base. % % Pre % --- % - The columns of V must be linearly independent (i.e. rank(V')==N). % - The number of columns of V must be >=1 and <=D. % % Post % ---- % - The returned Vo contains columns that are normal % (i.e. norm(Vo(:,i))==1). % - The returned Vo contains columns that are orthogonal % (i.e. dot(Vo(:,i),Vo(:,j))==0 for i~=j). 
% Check params: if size(V,2)<1 || size(V,2)>size(V,1) error('N columns must be >=1 and <=D rows!'); end % Check for linear dependence: if rank(V')~=size(V,2) error('The V vectors must be linearly independent!'); end % Normalizing the V vectors: normV = sqrt(sum(V.^2)); normV(normV==0) = eps; V = V./repmat(normV,[size(V,1),1]); % Orthogonalizing: for i=2:size(V,2) % Computing the projection on the subspace spanned by the already orthonormalized columns 1..i-1: proj = V(:,1)*dot(V(:,1),V(:,i)); for j=2:i-1 % The projection of the Vi vector over the Vj orthonormal versor: proj = proj + V(:,j)*dot(V(:,j),V(:,i)); end % Removing and normalizing: V(:,i) = V(:,i) - proj; V(:,i) = V(:,i)./norm(V(:,i)); end % Returning: Vo = V; ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/parseParamsNamed.m ================================================ function pars = parseParamsNamed(parsCell,recurse,parsi,multi) % PARSEPARAMSNAMED Getting named parameters as a structure. % % function pars = parseParamsNamed(parsCell,recurse,parsi,multi) % % This function allows to get named parameters as a struct object. Named % parameters are sequences of name/value pairs where the string % represents the name of the parameter and the value its value. % % Parameters % ---------- % IN: % parsCell = The cell containing the parameters (like varargin). % recurse = Must the parsing recurse in sub-cells? (def=false) % parsi = Initial struct (with default values). (def=struct) % multi = Must the multi-valued fields be grouped in a cell? (def=false) % OUT: % pars = The structure containing the parameters. 
% Check parameters: if nargin<1 || ~isa(parsCell,'cell') error('A cell with parameters must be passed as parameter!'); end if nargin<2; recurse=false; end if nargin<3; parsi=struct; end if nargin<4; multi=false; end % Check for the size: N = numel(parsCell); if mod(N,2)~=0 error('The cell must contain an even number of elements!'); end % Init: pars = parsi; % Iterating on parameters: for i=1:2:N % Get the name: name = parsCell{i}; if ~isa(name,'char') error('The name of parameters must be a string!'); end % Getting the value: value = parsCell{i+1}; if recurse && isa(value,'cell') % Generating the substructure: value = parseParamsNamed(value,recurse); end % Checking if multi: if multi % Init: if not(isfield(pars,name)) pars.(name) = {}; end if not(isa(pars.(name),'cell')) pars.(name) = {pars.(name)}; end % Packing: pars.(name){end+1} = value; else % Saving the last: pars.(name) = value; end end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/private/DANCo_estimateKL.m ================================================ function kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref) %DANCO_ESTIMATEKL Estimating the Kullback-Leibler divergence for DANCo % % function kl = DANCo_estimateKL(k,dHat,mu,tau,dHat_ref,mu_ref,tau_ref) % % This function estimates the Kullback-Leibler divergence between the % distances/angles bivariate distributions over two datasets whose % statistic parameters are {dHat,mu,tau} and {dHat_ref,mu_ref,tau_ref} % respectively. % % Parameters % ---------- % IN: % k = The KNN parameter. % dHat,mu,tau = Statistics for the real dataset. % dHat_ref,mu_ref,tau_ref = Statistics for the synthetic dataset. % Checking parameters: if nargin<7; error('Too few parameters'); end % Doing the computation: kl = EstimateKL_dists(dHat,dHat_ref,k) + ... 
EstimateKL_angles(mu,tau,mu_ref,tau_ref); end % ------------------------ LOCAL FUNCTIONS ------------------------ % Estimating the KL-divergence for all the cases d=1:D: function kl = EstimateKL_angles(m1,k1,m2,k2) % The bessel values: bk2 = besseli(0,k2); bk1 = besseli(0,k1); ak1 = besseli(1,k1); ck1 = besseli(1,-k1); % Checking for infs: if(isinf(k1)); k1 = realmax; end k2(isinf(k2)) = realmax; bk2(isinf(bk2)) = realmax; if(isinf(bk1)); bk1 = realmax; end if(isinf(ak1)); ak1 = realmax; end if(isinf(ck1)); ck1 = realmax; end % Computing the KL: kl = abs(log(bk2./bk1) + (ak1-ck1)./(2*bk1).*(k1-k2.*cos(m2-m1))); end % ----------------------------------------------------------------- % Estimating the KL-divergence for the distances: function kl = EstimateKL_dists(d1,d2,k) % Initializing: bin = zeros(numel(d2),k+1); % Computing the terms in the summation: for i=1:k+1 bin(:,i) = (-1)^(i-1)*nchoosek(k,i-1).*psi(1+(i-1)*(d1./d2)); end % Estimating the KL in closed form: kl = -(1+harmonic(k-1)-harmonic(k)*(d2./d1)+log(d2./d1)+(k-1)*sum(bin,2)'); end % ----------------------------------------------------------------- % Harmonic numbers (by David Terr): exact for z==1, otherwise the % asymptotic expansion log(z) + gamma + 1/(2z) - 1/(12z^2) + 1/(120z^4) - 1/(252z^6), % where gamma = 0.5772156649... is the Euler-Mascheroni constant: function h = harmonic(z) if z == 1 h = 1; else h = log(z) + 0.5772156649 + 1/(2*z) - 1/(12*z^2) + 1/(120*z^4) - 1/(252*z^6); end end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/private/DANCo_statistics.m ================================================ function [isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,force) %DANCO_STATISTICS Computes the statistics used by DANCo % % function [isLowDim,dHat,mu,tau] = DANCo_statistics(data,k,params,force) % % This function allows to compute the statistic parameters used by DANCo % to estimate the intrinsic dimensionality, in case of isLowDim==true only % the dHat value is computed (the other output values are set to 0 in this % case). 
% % Parameters % ---------- % IN: % data = The dataset as a matrix with columnwise vectors. % k = KNN parameter. % params = The named parameters passed to DANCo (see DANCo or DANCoFit). % force = Forces the computation of mu and tau. (def=false). % OUT: % isLowDim = True if the intrinsic dimensionality is considered low, in % this case the dHat estimation is considered good enough. % dHat = Estimated id by means of biased estimators like MLE or MiND. % mu,tau = Von Mises circular statistic parameters. % Checking parameters: if nargin<3; error('A dataset, k, and the parameters are required'); end if nargin<4; force=false; end if not(isa(params,'struct')); error('Params must be a struct'); end if not(isfield(params,'mlMethod')); error('An mlMethod is required'); end if not(isfield(params,'minCorrectionDim')); error('A minCorrectionDim is required'); end % Infos: N = size(data,2); % Is the knn already there? mustDoKnn = true; if isfield(params,'inds') && isfield(params,'dists') % Extracting and checking: inds = params.inds; dists = params.dists; if size(dists,1) == N && size(dists,2) >= k+2 && size(inds,1) == N && size(inds,2) >= k+2 % KNN already present: dists = dists(:,1:k+2); inds = inds(:,1:k+2); mustDoKnn = false; end end % Finding the KNNs only if needed: if mustDoKnn % It's needed! 
[inds,dists] = KNN(data,k,true); end % Normalizing the distances: if not(mustDoKnn || (isfield(params,'normalized') && params.normalized)) % Cropping both dists and inds: dists = dists(:,2:end); inds = inds(:,2:end); % Normalizing and removing the last element: dists = dists(:,1:end-1)./repmat(dists(:,end),[1,size(dists,2)-1]); end % Estimating dHat: switch params.mlMethod case 'MLE' dHat = MLE(data,'dists',dists,'normalized',true); case 'MiND_MLI' dHat = MiND_ML(data,'dists',dists,'optimize',false,'normalized',true); case 'MiND_MLK' dHat = MiND_ML(data,'dists',dists,'optimize',true,'normalized',true); otherwise error(['Unknown mlMethod: ' params.mlMethod]); end % Correction based on the KL-divergence: isLowDim = dHat < params.minCorrectionDim; if not(isLowDim) || force % The circular statistics: angles = ComputeAngles(data,inds); [mu,tau] = CircStats(angles); % Averaging the parameters: mu = angle(sum(exp(1i*mu))); tau = mean(tau); else % Fake parameters: mu = 0; tau = 0; end end % ------------------------ LOCAL FUNCTIONS ------------------------ % Computing the used angular statistics: function [mu,tau] = CircStats(angles) % Computing the complex circular means: cm = sum(exp(1i*angles),2); % Computing the von Mises parameters: mu = angle(cm); r = abs(cm)/size(angles,2); % The R function: % NOTE(review): r has one entry per row of angles, so `if r < 0.53` is true only when ALL entries satisfy it; a per-element (masked) evaluation may have been intended. Left unchanged because the reference fits (DANCo_fits.mat) were presumably produced with this exact code - confirm before changing. if r < 0.53; tau = 2*r + r.^3 + 5*r.^5./6; else if r < 0.85; tau = -0.4 + 1.39*r + 0.43./(1-r); else tau = 1./(r.^3 - 4*r.^2 + 3*r); end end % Correcting for small sample sizes: if size(angles,2) < 15 mask = tau < 2; tau(mask) = max(tau(mask)-2./(size(angles,2).*tau(mask)),0); tau(~mask) = (size(angles,2)-1).^3.*tau(~mask)./(size(angles,2).^3+size(angles,2)); end end % ----------------------------------------------------------------- % Computing the pairwise angles for neighbours: function angles = ComputeAngles(data,inds) % Init: K = size(inds,2); angles = zeros(size(inds,1),K*(K-1)/2); cnk = combnk(1:size(inds,2),2); % Iterating on points: for i=1:size(data,2) % Translating and 
normalizing points: pts = data(:,inds(i,:)) - repmat(data(:,i),[1,K]); pts = pts./repmat(sqrt(sum(pts.^2,1)),[size(pts,1),1]); % Computing all the pairwise angles: a = pts(:,cnk(:,1)); b = pts(:,cnk(:,2)); angles(i,:) = atan2(sqrt(sum((b-repmat(sum(a.*b,1),[size(a,1),1]).*a).^2,1)),sum(a.*b,1)); end end ================================================ FILE: analysis/intrinsic_dimension_estimation/matlab_codes/randsphere.m ================================================ function X = randsphere(d,N,r) % RANDSPHERE Random points uniformly drawn from a spere with radius r % % function X = randsphere(d,N,r) % % This function generates random data samples from a uniform hypersphere % with a given radius. The algorithm is the one reported by Roger Stafford % on matlabcentral/fileexchange. % % Parameters % ---------- % IN: % d = Space dimensionality. (def=2) % N = Number of points. (def=1000) % r = Hypersphere radius. (def=1) % OUT: % X = Matrix of points (columnwise). % Checking parameters: if nargin<1; d=2; end if nargin<2; N=1000; end if nargin<3; r=1; end % Gaussian randomly sampled points: X = randn(d,N); % Squared norms: s2 = sum(X.^2,1); % Correcting the points norms: X = X.*repmat(r*(gammainc(s2/2,d/2).^(1/d))./sqrt(s2),d,1); end ================================================ FILE: analysis/intrinsic_dimension_estimation/methods.py ================================================ import numpy as np import os from sklearn.neighbors import NearestNeighbors def kNN(X, n_neighbors, n_jobs): neigh = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=n_jobs).fit(X) dists, inds = neigh.kneighbors(X) return dists, inds def Levina_Bickel(X, dists, k): m = np.log(dists[:, k:k+1] / dists[:, 1:k]) m = (k-2) / np.sum(m, axis=1) dim = np.mean(m) return dim def start_matlab_engine(): import matlab.engine eng = matlab.engine.start_matlab() eng.addpath(os.path.join(os.path.split(__file__)[0], 'matlab_codes')) return eng def MiND_ML(X, dists, k): import matlab.engine eng = 
start_matlab_engine() X_mat = matlab.double(X.T.tolist()) dists_mat = matlab.double(dists[:, :k+2].tolist()) dim = eng.MiND_ML(X_mat, 'dists', dists_mat, 'normalized', False, 'optimize', True) return dim def MiND_KL(X, dists, k, maxDim=30): import matlab.engine eng = start_matlab_engine() X_mat = matlab.double(X.T.tolist()) dists_mat = matlab.double(dists[:, :k+2].tolist()) dim = eng.MiND_KL(X_mat, 'k', matlab.double([k]), 'maxDim', matlab.double([maxDim]), 'dists', dists_mat, 'normalized', False, nargout=1) return dim def DANCo(X, dists, inds, k, maxDim=30): import matlab.engine eng = start_matlab_engine() X_mat = matlab.double(X.T.tolist()) dists_mat = matlab.double(dists[:, :k+2].tolist()) inds_mat = matlab.int32((inds[:, :k+2]+1).tolist()) # fit Matlab indices dim = eng.DANCo(X_mat, 'k', matlab.double([k]), 'maxDim', matlab.double([maxDim]), 'fractal', True, 'dists', dists_mat, 'inds', inds_mat, 'normalized', False, nargout=1) return dim def Hein_CD(X): import matlab.engine eng = start_matlab_engine() X_mat = matlab.double(X.T.tolist()) dim = eng.GetDim(X_mat, nargout=1) dim = np.array(dim)[0] return dim[0], dim[1] ================================================ FILE: analysis/latent_regression/__init__.py ================================================ from .regressors import * ================================================ FILE: analysis/latent_regression/regressors.py ================================================ import numpy as np from sklearn.preprocessing import StandardScaler from sklearn.decomposition import PCA from sklearn.neural_network import MLPRegressor from sklearn.linear_model import LinearRegression def mlp_regress(X_train, X_test, y_train, y_test, random_state=0): scaler = StandardScaler() nn = MLPRegressor(hidden_layer_sizes=(64, 64, 64, 64, 64), activation='relu', solver='adam', learning_rate='adaptive', alpha=0.01, random_state=random_state) X_train_scaled = scaler.fit_transform(X_train) X_test_scaled = scaler.transform(X_test) 
y_norm = np.max(np.abs(y_train)) nn.fit(X_train_scaled, y_train / y_norm) train_error = np.mean(np.abs(nn.predict(X_train_scaled) * y_norm - y_train)) test_error = np.mean(np.abs(nn.predict(X_test_scaled) * y_norm - y_test)) return train_error, test_error, {'scaler':scaler, 'model':nn} def lin_regress(X_train, X_test, y_train, y_test): reg = LinearRegression() reg.fit(X_train, y_train) train_error = np.mean(np.abs(reg.predict(X_train) - y_train)) test_error = np.mean(np.abs(reg.predict(X_test) - y_test)) return train_error, test_error, reg def pca(X_train, X_test, num_components, random_state=0): pca = PCA(num_components, random_state=random_state) pca.fit(X_train) return pca.transform(X_train), pca.transform(X_test), pca ================================================ FILE: configs/air_dancer/latentpred/config1.yaml ================================================ lr: 0.01 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/latentpred/config2.yaml ================================================ lr: 0.01 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/latentpred/config3.yaml ================================================ lr: 0.02 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: 
configs/air_dancer/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: 
configs/air_dancer/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/refine64/config1.yaml ================================================ lr: 0.0006 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/refine64/config2.yaml ================================================ lr: 0.0003 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/air_dancer/refine64/config3.yaml ================================================ lr: 0.0004 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'air_dancer' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/circular_motion/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: 
configs/circular_motion/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/circular_motion/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/circular_motion/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/circular_motion/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/circular_motion/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 
================================================ FILE: configs/circular_motion/refine64/config1.yaml ================================================ lr: 0.0003 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/circular_motion/refine64/config2.yaml ================================================ lr: 0.0003 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/circular_motion/refine64/config3.yaml ================================================ lr: 0.0003 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'circular_motion' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/latentpred/config1.yaml ================================================ lr: 0.03 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/latentpred/config2.yaml ================================================ lr: 0.03 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'double_pendulum' 
lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/latentpred/config3.yaml ================================================ lr: 0.03 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' 
data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/refine64/config1.yaml ================================================ lr: 0.0003 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/refine64/config2.yaml ================================================ lr: 0.0005 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/double_pendulum/refine64/config3.yaml ================================================ lr: 0.0004 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 256 val_batch: 256 
test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'double_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/latentpred/config1.yaml ================================================ lr: 0.02 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/latentpred/config2.yaml ================================================ lr: 0.02 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/latentpred/config3.yaml ================================================ lr: 0.01 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: 
True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/refine64/config1.yaml 
================================================ lr: 0.0007 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/refine64/config2.yaml ================================================ lr: 0.0004 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/elastic_pendulum/refine64/config3.yaml ================================================ lr: 0.0002 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 256 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'elastic_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/latentpred/config1.yaml ================================================ lr: 0.008 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/latentpred/config2.yaml ================================================ lr: 0.02 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/latentpred/config3.yaml 
================================================ lr: 0.009 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/model64/config2.yaml ================================================ lr: 0.001 
seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'fire' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/refine64/config1.yaml ================================================ lr: 0.0007 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'fire' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/refine64/config2.yaml ================================================ lr: 0.0006 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'fire' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/fire/refine64/config3.yaml ================================================ lr: 0.0007 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'fire' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/latentpred/config1.yaml ================================================ lr: 0.008 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' 
train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/latentpred/config2.yaml ================================================ lr: 0.009 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/latentpred/config3.yaml ================================================ lr: 0.01 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' 
train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/refine64/config1.yaml ================================================ lr: 0.0006 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/refine64/config2.yaml ================================================ lr: 0.0004 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' 
train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/lava_lamp/refine64/config3.yaml ================================================ lr: 0.0005 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'lava_lamp' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/latentpred/config1.yaml ================================================ lr: 0.01 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/latentpred/config2.yaml ================================================ lr: 0.03 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/latentpred/config3.yaml ================================================ lr: 0.01 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/model/config1.yaml ================================================ lr: 
0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: 
configs/reaction_diffusion/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/refine64/config1.yaml ================================================ lr: 0.0004 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/refine64/config2.yaml ================================================ lr: 0.0003 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/reaction_diffusion/refine64/config3.yaml ================================================ lr: 0.0003 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 256 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'reaction_diffusion' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/latentpred/config1.yaml ================================================ lr: 0.005 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 
1 epochs: 1000 ================================================ FILE: configs/single_pendulum/latentpred/config2.yaml ================================================ lr: 0.01 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/latentpred/config3.yaml ================================================ lr: 0.02 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 
'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/refine64/config1.yaml ================================================ lr: 0.0003 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/refine64/config2.yaml ================================================ lr: 0.0003 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 
model_name: 'refine-64' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/single_pendulum/refine64/config3.yaml ================================================ lr: 0.0004 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'single_pendulum' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' 
train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/refine64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/refine64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_magnetic/refine64/config3.yaml 
================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'swingstick_magnetic' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/latentpred/config1.yaml ================================================ lr: 0.02 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/latentpred/config2.yaml ================================================ lr: 0.03 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/latentpred/config3.yaml ================================================ lr: 0.01 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'latent-prediction' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/model/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] 
num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/model/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/model/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/model64/config1.yaml ================================================ lr: 0.001 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/model64/config2.yaml ================================================ lr: 0.001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/model64/config3.yaml ================================================ lr: 0.001 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 
num_workers: 8 model_name: 'encoder-decoder-64' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [20, 50, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/refine64/config1.yaml ================================================ lr: 0.0005 seed: 1 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/refine64/config2.yaml ================================================ lr: 0.0001 seed: 2 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: configs/swingstick_non_magnetic/refine64/config3.yaml ================================================ lr: 0.0004 seed: 3 if_cuda: True gamma: 0.5 log_dir: 'logs' train_batch: 512 val_batch: 256 test_batch: 256 num_workers: 8 model_name: 'refine-64' data_filepath: './data/' dataset: 'swingstick_non_magnetic' lr_schedule: [15, 30, 100, 300] num_gpus: 1 epochs: 1000 ================================================ FILE: datainfo/README.md ================================================ ## Data Collection - single_pendulum, elastic_pendulum, reaction_diffusion, circular_motion: simulation codes to generate the simulation data. - double_pendulum: post processings. - fire, lava_lamp: split long sequences. - utils: split long sequences. ## Data Structure for Using Your Own Dataset All of our datasets follow the structure below. Please prepare any new dataset as follows to use the current dataloader. 
If you want to use a different file organization, you will need to write your own dataloader. ``` \double_pendulum \0 \0.png \1.png \... \1 \... \2 \... ... ``` ================================================ FILE: datainfo/air_dancer/data_split_dict_1.json ================================================ { "test": [ 2, 18, 4 ], "val": [ 15, 3, 8 ], "train": [ 23, 24, 11, 10, 22, 1, 5, 25, 19, 13, 9, 17, 16, 0, 7, 21, 6, 12, 20, 14 ] } ================================================ FILE: datainfo/air_dancer/data_split_dict_2.json ================================================ { "test": [ 24, 2, 1 ], "val": [ 9, 5, 11 ], "train": [ 3, 20, 17, 15, 0, 23, 4, 7, 14, 16, 19, 22, 12, 18, 10, 13, 21, 25, 6, 8 ] } ================================================ FILE: datainfo/air_dancer/data_split_dict_3.json ================================================ { "test": [ 17, 18, 7 ], "val": [ 19, 11, 4 ], "train": [ 6, 16, 23, 14, 1, 5, 21, 10, 9, 13, 25, 12, 3, 8, 22, 20, 0, 2, 24, 15 ] } ================================================ FILE: datainfo/circular_motion/data_split_dict_1.json ================================================ { "test": [ 524, 1032, 960, 907, 977, 840, 877, 802, 392, 5, 746, 980, 623, 1068, 675, 1098, 931, 470, 361, 591, 867, 975, 352, 526, 414, 237, 561, 880, 942, 552, 1059, 789, 12, 1005, 464, 345, 348, 806, 631, 89, 961, 60, 1002, 758, 1046, 335, 221, 898, 177, 767, 751, 354, 848, 827, 497, 983, 70, 805, 1034, 1022, 581, 621, 388, 1039, 1072, 1025, 681, 247, 607, 380, 204, 852, 44, 593, 941, 448, 472, 707, 477, 1015, 896, 454, 59, 864, 443, 780, 18, 52, 45, 62, 650, 209, 468, 545, 912, 4, 886, 798, 58, 999, 192, 429, 777, 967, 920, 1014, 241, 522, 129, 275 ], "val": [ 456, 164, 736, 36, 149, 406, 1075, 230, 21, 1017, 442, 620, 214, 1096, 747, 259, 111, 264, 1089, 815, 431, 351, 395, 319, 24, 116, 485, 508, 329, 719, 465, 301, 728, 663, 279, 672, 172, 540, 1063, 163, 171, 71, 297, 1011, 189, 639, 944, 112, 1004, 255, 287, 773, 772, 14, 
463, 17, 888, 85, 72, 689, 861, 33, 261, 836, 871, 816, 564, 817, 93, 881, 185, 598, 563, 181, 1079, 235, 823, 28, 614, 469, 339, 627, 1033, 638, 553, 551, 1, 1040, 424, 365, 832, 496, 423, 516, 963, 1019, 567, 583, 373, 890, 492, 57, 972, 436, 210, 574, 796, 531, 132, 828 ], "train": [ 391, 479, 51, 1077, 882, 625, 268, 197, 433, 635, 246, 685, 640, 803, 131, 398, 1070, 580, 924, 863, 962, 238, 39, 763, 653, 455, 608, 869, 680, 768, 634, 107, 514, 509, 884, 417, 219, 146, 311, 630, 996, 290, 80, 911, 366, 353, 130, 814, 1001, 765, 100, 565, 487, 498, 158, 147, 862, 887, 13, 529, 295, 444, 916, 795, 722, 260, 1016, 950, 609, 284, 895, 542, 535, 381, 450, 294, 750, 502, 494, 326, 556, 67, 1057, 384, 933, 143, 644, 362, 1084, 120, 16, 1094, 706, 788, 233, 905, 1036, 596, 733, 853, 418, 254, 119, 26, 559, 575, 549, 1088, 834, 966, 202, 434, 1095, 95, 308, 507, 590, 154, 988, 357, 1008, 425, 360, 234, 778, 819, 883, 90, 793, 611, 1028, 901, 812, 740, 432, 923, 68, 289, 167, 679, 784, 1093, 186, 671, 891, 849, 316, 729, 669, 213, 375, 801, 304, 991, 569, 9, 633, 603, 428, 190, 940, 642, 368, 771, 857, 282, 263, 419, 330, 1031, 420, 655, 876, 520, 1041, 544, 194, 791, 81, 122, 859, 3, 656, 1042, 201, 956, 753, 266, 262, 782, 243, 334, 54, 99, 341, 140, 280, 1027, 421, 787, 43, 401, 96, 115, 989, 1071, 973, 56, 734, 1061, 617, 850, 30, 928, 534, 573, 1067, 968, 949, 499, 343, 868, 990, 837, 242, 1092, 482, 1083, 992, 568, 350, 723, 969, 612, 467, 344, 810, 503, 1087, 693, 807, 619, 174, 717, 1076, 594, 29, 624, 586, 718, 336, 809, 926, 616, 441, 409, 1045, 1053, 954, 134, 451, 501, 537, 731, 533, 1010, 1006, 491, 708, 206, 688, 125, 927, 113, 770, 459, 709, 453, 390, 200, 713, 402, 906, 152, 538, 25, 379, 654, 766, 188, 267, 493, 661, 137, 776, 121, 764, 687, 430, 978, 570, 800, 47, 82, 701, 489, 959, 396, 762, 1020, 35, 216, 939, 50, 1007, 1026, 23, 948, 702, 438, 248, 987, 878, 77, 155, 278, 745, 135, 178, 998, 309, 37, 676, 759, 775, 870, 79, 1048, 539, 712, 699, 148, 
412, 613, 986, 385, 34, 446, 484, 915, 207, 774, 979, 199, 657, 378, 269, 184, 371, 271, 856, 845, 394, 532, 359, 997, 673, 476, 739, 478, 236, 821, 515, 372, 665, 865, 695, 831, 965, 285, 879, 892, 958, 331, 1000, 666, 1090, 889, 374, 1091, 974, 276, 156, 179, 585, 643, 226, 935, 0, 714, 781, 790, 510, 1049, 257, 223, 447, 293, 651, 530, 825, 69, 332, 370, 1052, 548, 1065, 525, 6, 382, 291, 474, 726, 413, 55, 415, 452, 104, 196, 292, 756, 519, 1078, 215, 647, 716, 913, 779, 437, 1013, 582, 866, 748, 829, 571, 705, 299, 527, 139, 698, 277, 75, 606, 88, 1066, 838, 427, 124, 648, 321, 15, 725, 296, 486, 270, 97, 2, 286, 383, 193, 127, 808, 151, 338, 854, 483, 1080, 1050, 307, 677, 61, 108, 636, 700, 318, 684, 546, 166, 909, 84, 696, 1035, 462, 602, 481, 195, 760, 860, 123, 473, 435, 227, 517, 358, 900, 337, 495, 715, 315, 397, 785, 142, 936, 737, 622, 769, 660, 1074, 168, 393, 356, 403, 1018, 53, 833, 855, 659, 159, 180, 386, 641, 126, 87, 475, 169, 1058, 599, 844, 902, 208, 543, 918, 377, 523, 73, 595, 224, 281, 1024, 422, 1085, 490, 506, 724, 244, 231, 1064, 512, 953, 541, 952, 7, 994, 239, 662, 176, 410, 407, 730, 914, 572, 749, 141, 363, 63, 19, 841, 903, 1043, 1069, 830, 1038, 405, 11, 1023, 597, 449, 670, 678, 562, 440, 457, 683, 839, 253, 49, 161, 550, 946, 306, 182, 820, 566, 323, 32, 558, 145, 211, 1097, 300, 1012, 109, 312, 938, 144, 153, 183, 1037, 826, 743, 652, 513, 943, 157, 674, 505, 367, 921, 692, 22, 76, 847, 757, 930, 752, 804, 249, 20, 872, 250, 94, 792, 649, 1029, 874, 103, 342, 251, 310, 592, 682, 191, 799, 42, 668, 811, 232, 985, 658, 588, 92, 458, 91, 929, 738, 369, 252, 203, 314, 710, 554, 187, 265, 364, 480, 704, 632, 220, 256, 114, 466, 615, 324, 65, 64, 408, 320, 400, 460, 327, 610, 727, 10, 27, 908, 325, 667, 212, 102, 322, 488, 894, 741, 910, 555, 744, 445, 105, 842, 1056, 697, 919, 1009, 118, 165, 899, 600, 245, 897, 754, 843, 971, 947, 932, 686, 628, 1021, 735, 46, 110, 283, 1081, 893, 786, 577, 302, 993, 273, 83, 579, 229, 951, 584, 
846, 824, 601, 629, 117, 349, 851, 150, 389, 74, 416, 40, 858, 813, 945, 1062, 797, 500, 732, 618, 783, 298, 904, 1054, 346, 376, 917, 995, 957, 340, 274, 794, 964, 170, 173, 136, 86, 41, 742, 66, 240, 1082, 970, 547, 703, 922, 560, 1051, 98, 818, 272, 218, 439, 347, 138, 576, 1047, 955, 160, 885, 288, 411, 626, 333, 934, 511, 981, 303, 399, 1055, 106, 504, 198, 605, 1073, 690, 587, 1044, 101, 355, 205, 387, 835, 521, 637, 720, 1086, 175, 471, 976, 222, 604, 38, 984, 8, 133, 258, 578, 426, 162, 761, 873, 317, 78, 937, 313, 48, 217, 128, 305, 755, 1030, 875, 646, 1003, 328, 822, 589, 691, 404, 31, 664, 536, 228, 461, 528, 711, 925, 645, 225, 1060, 557, 982, 694, 518, 721 ] } ================================================ FILE: datainfo/circular_motion/data_split_dict_2.json ================================================ { "test": [ 24, 688, 255, 176, 368, 371, 58, 32, 777, 734, 919, 433, 61, 901, 966, 215, 844, 250, 272, 874, 139, 534, 772, 108, 937, 896, 698, 232, 605, 1066, 50, 668, 588, 60, 217, 391, 17, 699, 154, 750, 1001, 425, 638, 832, 1032, 621, 633, 982, 1082, 340, 664, 454, 996, 935, 718, 1062, 931, 1057, 1025, 1020, 571, 1003, 511, 944, 818, 330, 1060, 741, 724, 1079, 849, 912, 372, 1052, 736, 1065, 1044, 279, 355, 665, 361, 48, 472, 483, 363, 336, 867, 778, 652, 952, 745, 56, 1090, 549, 1028, 911, 761, 1042, 805, 882, 324, 73, 434, 515, 631, 346, 739, 173, 187, 115 ], "val": [ 810, 639, 233, 1017, 82, 256, 75, 455, 731, 238, 252, 1041, 725, 520, 237, 650, 464, 98, 174, 165, 134, 34, 259, 995, 686, 143, 1049, 716, 18, 1048, 429, 620, 268, 265, 349, 924, 147, 335, 527, 498, 402, 799, 598, 530, 986, 895, 807, 458, 989, 104, 858, 323, 703, 676, 95, 230, 484, 157, 722, 636, 883, 411, 773, 270, 923, 46, 757, 619, 784, 564, 459, 315, 31, 500, 345, 292, 1098, 765, 760, 642, 630, 352, 4, 37, 155, 253, 813, 44, 603, 394, 1, 708, 535, 188, 752, 160, 958, 1033, 130, 261, 382, 21, 940, 746, 41, 25, 69, 977, 117, 84 ], "train": [ 521, 97, 212, 436, 465, 835, 410, 
195, 591, 288, 988, 828, 450, 127, 42, 93, 92, 970, 873, 239, 424, 1086, 956, 310, 266, 824, 142, 717, 119, 727, 1008, 357, 871, 584, 64, 396, 817, 406, 57, 2, 681, 9, 249, 864, 512, 198, 245, 40, 473, 1089, 299, 182, 601, 379, 321, 955, 763, 486, 766, 22, 906, 744, 171, 316, 702, 467, 370, 898, 274, 820, 1015, 644, 443, 707, 13, 576, 764, 109, 838, 559, 158, 26, 670, 721, 329, 789, 798, 802, 767, 133, 431, 539, 325, 672, 485, 618, 140, 135, 438, 692, 408, 608, 997, 640, 29, 344, 572, 421, 748, 6, 327, 1039, 235, 308, 796, 635, 290, 581, 713, 178, 612, 943, 568, 866, 207, 546, 519, 917, 889, 214, 339, 111, 448, 35, 63, 248, 282, 405, 163, 629, 782, 226, 659, 850, 902, 487, 1058, 592, 540, 596, 505, 831, 801, 575, 1046, 800, 833, 963, 556, 758, 504, 347, 236, 206, 441, 682, 1085, 547, 1040, 507, 87, 1094, 580, 975, 604, 648, 569, 294, 7, 916, 607, 788, 228, 281, 573, 376, 219, 915, 241, 167, 208, 1080, 216, 529, 1075, 94, 690, 907, 542, 578, 146, 120, 968, 49, 862, 211, 680, 447, 815, 586, 1068, 488, 693, 306, 271, 269, 314, 43, 792, 797, 976, 417, 954, 191, 499, 151, 729, 854, 848, 552, 781, 307, 407, 826, 881, 985, 1097, 1069, 397, 423, 508, 168, 76, 998, 627, 439, 860, 258, 229, 105, 645, 590, 751, 377, 759, 791, 597, 80, 106, 887, 891, 386, 753, 1054, 322, 100, 809, 965, 300, 843, 661, 913, 899, 617, 283, 855, 567, 99, 842, 926, 737, 574, 634, 967, 440, 987, 948, 637, 960, 886, 416, 562, 51, 513, 1019, 662, 732, 803, 66, 356, 460, 419, 897, 152, 932, 193, 359, 945, 606, 476, 583, 660, 624, 786, 369, 466, 506, 110, 557, 403, 1012, 351, 378, 1091, 981, 276, 415, 0, 251, 427, 953, 432, 942, 770, 683, 616, 566, 877, 787, 677, 456, 257, 70, 15, 286, 71, 929, 33, 123, 689, 756, 518, 332, 827, 47, 780, 614, 790, 12, 876, 112, 632, 541, 671, 129, 845, 694, 234, 156, 201, 482, 490, 884, 348, 1022, 1021, 302, 380, 921, 738, 231, 1084, 197, 210, 723, 1016, 354, 205, 964, 1038, 517, 755, 593, 1064, 305, 918, 278, 691, 1005, 1006, 242, 675, 200, 27, 116, 492, 430, 1073, 
1030, 949, 687, 59, 295, 267, 920, 613, 1009, 398, 227, 880, 890, 90, 1081, 951, 785, 865, 331, 190, 172, 23, 312, 1061, 122, 362, 126, 646, 175, 45, 442, 678, 563, 446, 128, 169, 857, 775, 719, 1070, 991, 806, 623, 754, 77, 615, 684, 145, 463, 199, 304, 714, 666, 816, 863, 277, 868, 223, 333, 841, 878, 532, 808, 469, 39, 183, 85, 179, 53, 14, 65, 861, 704, 565, 1010, 457, 973, 244, 225, 647, 461, 132, 859, 247, 202, 936, 972, 962, 343, 544, 318, 1007, 149, 747, 959, 611, 385, 930, 479, 360, 445, 365, 384, 1092, 328, 287, 350, 409, 387, 743, 555, 1083, 1011, 728, 189, 776, 194, 480, 243, 101, 162, 334, 1018, 643, 649, 137, 847, 298, 452, 749, 170, 994, 554, 730, 79, 1043, 326, 218, 373, 388, 1037, 1045, 62, 284, 971, 705, 836, 888, 983, 622, 91, 516, 153, 543, 957, 837, 892, 933, 922, 673, 319, 341, 311, 651, 264, 783, 309, 551, 742, 136, 510, 879, 946, 453, 577, 654, 550, 900, 814, 969, 641, 570, 289, 10, 1095, 301, 653, 390, 1051, 166, 55, 138, 978, 908, 1034, 1067, 81, 28, 471, 514, 273, 470, 762, 496, 829, 67, 658, 795, 3, 358, 74, 209, 478, 177, 579, 525, 667, 493, 834, 8, 141, 626, 501, 779, 594, 600, 669, 204, 905, 812, 856, 220, 395, 910, 582, 221, 655, 1047, 823, 181, 822, 1026, 846, 320, 11, 894, 1024, 412, 192, 938, 545, 934, 83, 338, 560, 875, 1072, 710, 925, 1002, 414, 280, 494, 54, 522, 999, 184, 392, 609, 30, 451, 477, 148, 587, 367, 413, 240, 1071, 161, 1096, 263, 526, 254, 528, 1074, 720, 872, 553, 159, 114, 293, 536, 992, 449, 337, 711, 903, 774, 399, 735, 131, 38, 663, 697, 88, 68, 426, 150, 558, 180, 364, 72, 1014, 589, 599, 86, 715, 695, 712, 437, 561, 1055, 610, 1078, 974, 852, 769, 303, 696, 503, 909, 1063, 489, 870, 495, 275, 383, 144, 468, 366, 297, 509, 1013, 947, 1050, 481, 674, 1035, 1059, 1087, 709, 840, 400, 656, 342, 1093, 939, 125, 853, 260, 124, 404, 353, 771, 904, 794, 733, 401, 979, 585, 118, 196, 502, 941, 224, 78, 740, 804, 811, 927, 291, 5, 885, 628, 602, 213, 474, 497, 1077, 420, 1029, 462, 16, 990, 793, 203, 313, 851, 103, 
422, 839, 1053, 950, 381, 706, 296, 375, 625, 121, 531, 19, 374, 491, 821, 679, 96, 185, 830, 537, 428, 52, 1088, 595, 980, 523, 435, 444, 825, 1036, 1004, 1076, 701, 1023, 389, 657, 548, 317, 914, 475, 685, 533, 984, 222, 107, 893, 869, 186, 20, 102, 928, 246, 89, 1056, 524, 113, 164, 418, 393, 36, 1027, 961, 768, 538, 285, 1000, 700, 262, 993, 726, 1031, 819 ] } ================================================ FILE: datainfo/circular_motion/data_split_dict_3.json ================================================ { "test": [ 887, 659, 395, 532, 890, 471, 385, 386, 882, 918, 141, 981, 368, 321, 347, 888, 1003, 43, 706, 159, 269, 625, 298, 417, 994, 202, 971, 32, 548, 614, 110, 78, 7, 317, 1065, 483, 571, 677, 773, 92, 90, 243, 850, 1082, 601, 41, 1085, 840, 136, 704, 181, 990, 987, 129, 254, 583, 546, 432, 213, 668, 334, 572, 58, 689, 475, 834, 718, 790, 1038, 862, 1076, 893, 528, 444, 1013, 278, 73, 199, 748, 274, 910, 808, 874, 793, 968, 551, 63, 616, 87, 326, 131, 31, 798, 1071, 310, 474, 308, 813, 975, 963, 392, 479, 531, 960, 26, 134, 970, 757, 267, 487 ], "val": [ 223, 897, 759, 344, 81, 173, 876, 829, 448, 229, 853, 691, 341, 329, 930, 104, 99, 714, 665, 445, 694, 191, 335, 244, 275, 823, 669, 227, 512, 1022, 1094, 752, 700, 582, 27, 832, 1056, 107, 996, 806, 307, 270, 988, 864, 936, 776, 320, 189, 372, 1039, 327, 615, 606, 305, 467, 643, 257, 378, 21, 692, 62, 974, 22, 501, 755, 285, 723, 623, 950, 939, 695, 361, 477, 340, 642, 648, 61, 1037, 647, 603, 630, 995, 20, 322, 593, 425, 807, 11, 1005, 561, 1083, 533, 264, 447, 1035, 958, 1030, 732, 737, 649, 441, 277, 519, 830, 1088, 635, 105, 1050, 697, 609 ], "train": [ 401, 148, 464, 711, 957, 1084, 420, 1007, 318, 1074, 579, 972, 434, 2, 770, 220, 1010, 540, 192, 740, 857, 788, 545, 408, 352, 163, 469, 118, 758, 534, 506, 587, 504, 396, 863, 760, 231, 794, 1059, 1049, 768, 655, 456, 279, 313, 459, 769, 371, 80, 747, 66, 67, 525, 804, 521, 144, 1081, 894, 1060, 1066, 682, 439, 476, 142, 576, 620, 38, 672, 909, 
345, 328, 1019, 193, 1087, 713, 179, 413, 156, 607, 640, 898, 1046, 511, 1027, 253, 558, 135, 781, 1098, 599, 945, 1073, 157, 905, 221, 197, 177, 59, 219, 154, 178, 486, 421, 727, 64, 53, 402, 1097, 355, 751, 228, 646, 631, 852, 452, 627, 1052, 309, 292, 462, 325, 784, 555, 83, 225, 149, 350, 1025, 1031, 468, 427, 969, 49, 846, 836, 262, 473, 1012, 520, 12, 573, 60, 212, 1048, 973, 820, 671, 374, 460, 249, 859, 171, 172, 198, 218, 407, 746, 908, 502, 112, 1036, 450, 594, 1089, 999, 931, 809, 301, 137, 236, 819, 639, 753, 843, 1002, 1069, 75, 207, 100, 214, 300, 346, 720, 8, 578, 89, 188, 272, 1070, 418, 84, 725, 167, 238, 875, 800, 470, 390, 1055, 835, 211, 312, 217, 656, 0, 405, 151, 941, 1026, 174, 375, 586, 304, 814, 315, 265, 855, 1072, 929, 143, 565, 778, 485, 779, 337, 201, 562, 399, 186, 709, 500, 608, 1080, 362, 29, 743, 621, 42, 48, 121, 949, 1040, 696, 266, 557, 685, 6, 690, 624, 10, 1044, 955, 952, 719, 424, 16, 680, 162, 415, 287, 575, 821, 411, 363, 780, 503, 754, 71, 14, 86, 935, 702, 394, 299, 717, 721, 921, 589, 977, 693, 658, 889, 916, 496, 114, 296, 185, 938, 842, 712, 792, 772, 19, 233, 454, 224, 1018, 896, 789, 966, 155, 698, 106, 216, 687, 872, 242, 449, 404, 82, 535, 480, 209, 998, 482, 629, 248, 113, 1041, 705, 1016, 673, 775, 40, 17, 370, 1034, 68, 841, 168, 165, 507, 9, 801, 762, 39, 259, 364, 316, 633, 803, 25, 319, 666, 239, 708, 914, 728, 145, 1057, 412, 815, 472, 518, 699, 108, 619, 397, 735, 564, 652, 684, 745, 251, 119, 338, 951, 943, 379, 97, 797, 387, 559, 1061, 466, 458, 351, 844, 1032, 324, 158, 679, 282, 161, 443, 208, 828, 15, 899, 605, 526, 103, 670, 653, 1045, 543, 451, 517, 597, 724, 976, 303, 65, 516, 1000, 574, 343, 1047, 290, 96, 357, 674, 416, 654, 667, 288, 617, 1058, 1095, 513, 437, 77, 57, 764, 498, 457, 736, 414, 822, 993, 206, 170, 928, 263, 393, 356, 681, 831, 628, 774, 937, 816, 913, 196, 138, 956, 750, 660, 200, 923, 294, 1064, 283, 510, 1028, 366, 707, 810, 947, 650, 595, 610, 246, 234, 95, 331, 153, 284, 854, 
932, 638, 1001, 783, 169, 481, 72, 286, 55, 1075, 817, 901, 911, 455, 926, 91, 76, 741, 373, 489, 756, 777, 380, 505, 927, 967, 596, 367, 536, 180, 560, 252, 891, 1051, 1009, 585, 289, 538, 552, 946, 1024, 917, 644, 730, 591, 877, 377, 686, 895, 256, 989, 115, 442, 1020, 281, 924, 676, 786, 400, 742, 215, 984, 497, 388, 446, 339, 342, 30, 824, 997, 463, 365, 839, 851, 920, 664, 120, 866, 93, 881, 922, 925, 641, 884, 261, 845, 900, 140, 1015, 194, 867, 260, 865, 749, 715, 235, 478, 811, 767, 74, 802, 688, 570, 1090, 210, 491, 622, 147, 1067, 892, 645, 731, 176, 205, 391, 490, 1068, 23, 314, 222, 598, 791, 302, 550, 237, 703, 744, 150, 965, 273, 139, 190, 733, 453, 612, 611, 509, 98, 166, 152, 410, 837, 495, 529, 983, 878, 964, 164, 422, 240, 907, 661, 85, 701, 569, 358, 406, 613, 376, 785, 403, 959, 465, 255, 636, 247, 79, 184, 954, 332, 563, 782, 566, 102, 567, 109, 796, 541, 523, 663, 37, 1029, 306, 592, 291, 493, 117, 1093, 683, 787, 24, 183, 384, 101, 675, 94, 761, 45, 886, 241, 146, 1096, 5, 1033, 985, 544, 577, 1078, 111, 1092, 871, 795, 726, 860, 124, 268, 982, 383, 333, 28, 133, 436, 1004, 435, 738, 604, 276, 1017, 879, 849, 18, 716, 428, 430, 739, 508, 953, 381, 915, 979, 833, 44, 651, 722, 554, 880, 618, 1063, 382, 1006, 182, 580, 883, 389, 3, 568, 130, 126, 438, 336, 870, 271, 369, 311, 600, 398, 662, 991, 359, 1023, 869, 160, 323, 942, 514, 527, 88, 729, 827, 494, 70, 766, 127, 47, 1042, 56, 1, 330, 484, 52, 433, 992, 539, 632, 903, 1043, 848, 51, 409, 584, 934, 499, 1021, 1079, 280, 245, 799, 1086, 547, 125, 1014, 226, 360, 948, 1054, 553, 258, 128, 818, 116, 626, 54, 838, 549, 537, 961, 904, 902, 919, 515, 175, 873, 861, 492, 13, 50, 590, 440, 203, 524, 710, 35, 46, 250, 122, 293, 602, 69, 232, 349, 556, 295, 1091, 765, 678, 771, 933, 763, 33, 906, 734, 986, 1062, 522, 637, 488, 4, 204, 1008, 423, 36, 419, 581, 429, 297, 426, 978, 354, 1053, 885, 812, 530, 1011, 431, 132, 940, 353, 634, 825, 1077, 657, 847, 868, 348, 944, 187, 588, 858, 856, 826, 962, 
195, 542, 34, 123, 805, 230, 980, 461, 912 ] } ================================================ FILE: datainfo/collect/circular_motion/make_data.py ================================================ import numpy as np import os from tqdm import tqdm import shutil from PIL import Image, ImageDraw def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) # background and object properties img_size = (128, 128) bg_color = 'gray' obj_size = 8 obj_shape = 'ball' obj_color ='blue' def coord2img(x, y): img = Image.new('RGB', img_size, bg_color) draw = ImageDraw.Draw(img) pos_x = int(x * img_size[0]) pos_y = int((1-y) * img_size[1]) pos = (pos_x-obj_size, pos_y-obj_size, pos_x+obj_size, pos_y+obj_size) if obj_shape == 'square': draw.rectangle(pos, fill=obj_color) elif obj_shape == 'ball': draw.ellipse(pos, fill=obj_color) else: assert False return img data_filepath = './circular_motion' num_exp = 1100 num_frames = 60 # data_filepath = './circular_motion_long' # num_exp = 1 # num_frames = 60*20 center = (0.5, 0.5) radius = 0.3 np.random.seed(0) for p_exp in tqdm(range(num_exp)): theta = np.random.uniform(0, 2*np.pi) theta_d = np.random.choice((-1, 1)) * np.random.uniform(np.pi/30, np.pi/10) seq_filepath = os.path.join(data_filepath, str(p_exp)) mkdir(seq_filepath) for p_frame in range(num_frames): x = center[0] + radius * np.cos(theta) y = center[1] + radius * np.sin(theta) img = coord2img(x, y) theta += theta_d img.save(os.path.join(seq_filepath, str(p_frame)+'.png')) ================================================ FILE: datainfo/collect/double_pendulum/README.md ================================================ ## Double Pendulum Data Collection Suppose the raw videos recording the double pendulum motion are stored in ```./visphysics_data/double_pendulum_raw_videos``` in the current folder. Run the following commands. The collected data are saved in ```./visphysics_data/double_pendulum```. 1. 
python convert_video.py -f 60 (convert each video to a sequence of images with sampling frequency 60fps) 2. python equalize_background.py -r 215 -g 205 -b 192 (equalize the background color) 3. python split_data.py (split each sequence into subsequences with fixed length 60) ================================================ FILE: datainfo/collect/double_pendulum/convert_video.py ================================================ import cv2 import os import shutil from tqdm import tqdm import argparse import random import time ## This script is for the purposes of extracting frames from given set of input videos ## at desired fps. def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def get_arguments(): ap = argparse.ArgumentParser() ap.add_argument('-f', '--fps', required=True, type=int, help='frames per second at which it is desirable to get output data') ap.add_argument('-s', '--seed', default=0, type=int, help='random seed for shuffling') args = ap.parse_args() return args def write_frames(video_path, frame_path, fps, seed): # get list of all videos in video_path listing = os.listdir(video_path) # remove .DS_Store files if listing.__contains__('.DS_Store'): listing.remove('.DS_Store') # random suffle random.seed(seed) random.shuffle(listing) num_videos = len(listing) for (idx, filename) in enumerate(listing): print(idx+1, 'out of ', num_videos) cap = cv2.VideoCapture(os.path.join(video_path, filename)) video_fps = cap.get(cv2.CAP_PROP_FPS) print('video fps: ', video_fps) data_path = os.path.join(frame_path, str(idx)) mkdir(data_path) p_frame = 0 # frame count skip = round(video_fps / fps) # sampling frequency while(True): # read current frame ret, frame = cap.read() if ret == False: break # manual crop frame = frame[0:620, 371:991] # downsize frame = cv2.resize(frame, (128, 128), interpolation = cv2.INTER_AREA) if p_frame % skip == 0: cv2.imwrite(os.path.join(data_path, str(int(p_frame/skip))+'.png'), frame) p_frame += 1 
cap.release() cv2.destroyAllWindows() time.sleep(1.0) # short delay for cleaning purposes video_path = "./visphysics_data/double_pendulum_raw_videos" frame_path = "./visphysics_data/double_pendulum_raw_data" args = get_arguments() write_frames(video_path, frame_path, args.fps, args.seed) ================================================ FILE: datainfo/collect/double_pendulum/equalize_background.py ================================================ import cv2 import os import shutil from tqdm import tqdm import argparse import numpy as np ''' This script equalizes backgrounds of all frames with a uniform background. The background color is user-specified or computed by averaging background pixels of sampled frames from the data. ''' def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def get_arguments(): ap = argparse.ArgumentParser() ap.add_argument('-b', '--channel0', help='b channel') ap.add_argument('-g', '--channel1', help='g channel') ap.add_argument('-r', '--channel2', help='r channel') args = ap.parse_args() return args raw_filepath = "./visphysics_data/double_pendulum_raw_data" data_filepath = "./visphysics_data/double_pendulum_background_corrected_data" num_exps = len(os.listdir(raw_filepath)) bg = np.zeros(3) args = get_arguments() if args.channel0 is not None and args.channel1 is not None and args.channel2 is not None: # user-specified background color bg[0] = args.channel0 bg[1] = args.channel1 bg[2] = args.channel2 else: # uniform background by averaging sampled pixels percent = .05 # proportion of sampled frames cnt = 0 print('Computing uniform background...') for p_exp in tqdm(range(num_exps)): seq_filepath = os.path.join(raw_filepath, str(p_exp)) num_frames = len(os.listdir(seq_filepath)) num_samples = int(percent * num_frames) for p_frame in range(num_samples): # read each frame frame_path = os.path.join(seq_filepath, str(p_frame)+'.png') frame = cv2.imread(frame_path) # record all background pixels for i in 
range(frame.shape[0]): for j in range(frame.shape[1]): # manual threshold if (frame[i,j,2] > 130): bg += frame[i,j] cnt += 1 # average over all sampled pixels bg = bg / cnt print('Uniform Background R: ', bg[2], ' G: ', bg[1], ' B: ', bg[0]) # apply the background to all data print('Applying the background...') for p_exp in tqdm(range(num_exps)): seq_filepath = os.path.join(raw_filepath, str(p_exp)) data_seq_filepath = os.path.join(data_filepath, str(p_exp)) mkdir(data_seq_filepath) num_frames = len(os.listdir(seq_filepath)) for p_frame in range(num_frames): frame_path = os.path.join(seq_filepath, str(p_frame)+'.png') frame = cv2.imread(frame_path) for i in range(frame.shape[0]): for j in range(frame.shape[1]): # manual threshold if (frame[i,j,2] > 130): frame[i,j] = bg # write to data cv2.imwrite(os.path.join(data_seq_filepath, str(p_frame)+'.png'), frame) ================================================ FILE: datainfo/collect/double_pendulum/split_data.py ================================================ import os import shutil from tqdm import tqdm import argparse ''' This script splits each sequence, which corresponds to a single experiment, into subsequences with fixed length. 
''' def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def get_arguments(): ap = argparse.ArgumentParser() ap.add_argument('-l', '--len', type=int, default=60, help='length of each subsequence') args = ap.parse_args() return args src_filepath = "./visphysics_data/double_pendulum_background_corrected_data" dst_filepath = "./visphysics_data/double_pendulum" args = get_arguments() num_exps = len(os.listdir(src_filepath)) subseq_cnt = 0 frame_cnt = 0 mkdir(os.path.join(dst_filepath, str(subseq_cnt))) for p_exp in tqdm(range(num_exps)): src_seq_filepath = os.path.join(src_filepath, str(p_exp)) num_frames = len(os.listdir(src_seq_filepath)) # ignore frames at the end of the sequence num_frames = int(num_frames/args.len) * args.len for p_frame in range(num_frames): src_path = os.path.join(src_seq_filepath, str(p_frame)+'.png') dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.png') shutil.copyfile(src_path, dst_path) frame_cnt += 1 if frame_cnt == args.len: subseq_cnt += 1 frame_cnt = 0 mkdir(os.path.join(dst_filepath, str(subseq_cnt))) # remove the last empty folder shutil.rmtree(os.path.join(dst_filepath, str(subseq_cnt))) ================================================ FILE: datainfo/collect/elastic_pendulum/make_data.py ================================================ import os import shutil import numpy as np from tqdm import tqdm from PIL import Image, ImageDraw from scipy import linalg from scipy.integrate import solve_ivp def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def engine(rng, num_frm, fps=60): # physical parameters L_0 = 0.205 # elastic pendulum unstretched length (m) L = 0.179 # rigid pendulum rod length (m) w = 0.038 # rigid pendulum rod width (m) m = 0.110 # rigid pendulum mass (kg) g = 9.81 # gravitational acceleration (m/s^2) k = 40.0 # elastic constant (kg/s^2) dt = 1.0 / fps t_eval = np.arange(num_frm) * dt # solve equations of motion # y = [theta_1, 
theta_2, z, vel_theta_1, vel_theta_2, vel_z] def f(t, y): La = L_0 + y[2] Lb = 0.5 * L I = (1. / 12.) * m * (w ** 2 + L ** 2) Jc = L * np.cos(y[0] - y[1]) Js = L * np.sin(y[0] - y[1]) A = np.array([[La**2, 0.5*La*Jc, 0], [0.5*La*Jc, Lb**2+I/m, 0.5*Js], [0, 0.5*Js, 1]]) b = np.zeros(3) b[0] = -0.5*La*Js*y[4]**2 - 2*La*y[3]*y[5] - g*La*np.sin(y[0]) b[1] = 0.5*La*Js*y[3]**2 - Jc*y[3]*y[5] - g*Lb*np.sin(y[1]) b[2] = La*y[3]**2 + 0.5*Jc*y[4]**2 + g*np.cos(y[0]) - (k/m)*y[2] sol = linalg.solve(A, b) return [y[3], y[4], y[5], sol[0], sol[1], sol[2]] # run until the drawn pendulums are inside the image rej = True while rej: initial_state = [rng.uniform(0, 2*np.pi), rng.uniform(0, 2*np.pi), rng.uniform(-0.04, 0.04), rng.uniform(-6, 6), rng.uniform(-10, 10), 0] sol = solve_ivp(f, [t_eval[0], t_eval[-1]], initial_state, t_eval=t_eval, rtol=1e-6) states = sol.y.T lim_x = np.abs((L_0+states[:, 2])*np.sin(states[:, 0]) + L*np.sin(states[:, 1])) - (L_0 + L) lim_y = np.abs((L_0+states[:, 2])*np.cos(states[:, 0]) + L*np.cos(states[:, 1])) - (L_0 + L) rej = (np.max(lim_x) > 0.16) | (np.max(lim_y) > 0.16) | (np.min(states[:, 2]) < -0.08) return states def draw_rect(im, col, top_x, top_y, w, h, theta): x1 = top_x - w * np.cos(theta) / 2 y1 = top_y + w * np.sin(theta) / 2 x2 = x1 + w * np.cos(theta) y2 = y1 - w * np.sin(theta) x3 = x2 + h * np.sin(theta) y3 = y2 + h * np.cos(theta) x4 = x3 - w * np.cos(theta) y4 = y3 + w * np.sin(theta) pts = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] draw = ImageDraw.Draw(im) draw.polygon(pts, fill=col) def render(theta_1, theta_2, z): bg_color = (215, 205, 192) pd1_color = (63, 66, 85) pd2_color = (17, 93, 234) im = Image.new('RGB', (1600, 1600), bg_color) center = (800, 800) w1, w2, h1, h2 = 90, 90, 300, 250 L_0 = 0.205 h1 *= (1 + z / L_0) pd1_end_x = center[0] + (h1-35) * np.sin(theta_1) pd1_end_y = center[1] + (h1-35) * np.cos(theta_1) # pd1 may hide pd2 draw_rect(im, pd2_color, pd1_end_x, pd1_end_y, w2, h2, theta_2) draw_rect(im, pd1_color, 
center[0], center[1], w1, h1, theta_1) im = im.resize((128, 128)) return im def make_data(data_filepath, num_seq, num_frm, seed=0): mkdir(data_filepath) rng = np.random.default_rng(seed) states = np.zeros((num_seq, num_frm, 6)) for n in tqdm(range(num_seq)): seq_filepath = os.path.join(data_filepath, str(n)) mkdir(seq_filepath) states[n, :, :] = engine(rng, num_frm) for k in range(num_frm): im = render(states[n, k, 0], states[n, k, 1], states[n, k, 2]) im.save(os.path.join(seq_filepath, str(k)+'.png')) np.save(os.path.join(data_filepath, 'states.npy'), states) if __name__ == '__main__': data_filepath = '/data/kuang/visphysics_data/elastic_pendulum' make_data(data_filepath, num_seq=1200, num_frm=60) ================================================ FILE: datainfo/collect/fire/split_data.py ================================================ import os import shutil from tqdm import tqdm import argparse ''' This script splits each sequence, which corresponds to a single experiment, into subsequences with fixed length. 
''' def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def get_arguments(): ap = argparse.ArgumentParser() ap.add_argument('-l', '--len', type=int, default=60, help='length of each subsequence') args = ap.parse_args() return args src_train_filepath = "./visphysics_data/fire/train" src_test_filepath = "./visphysics_data/fire/test" dst_filepath = "./visphysics_data/fire_60" args = get_arguments() num_train_exps = len(os.listdir(src_train_filepath)) num_test_exps = len(os.listdir(src_test_filepath)) subseq_cnt = 0 frame_cnt = 0 mkdir(os.path.join(dst_filepath, str(subseq_cnt))) # train for p_exp in tqdm(range(num_train_exps)): src_seq_filepath = os.path.join(src_train_filepath, str(p_exp)) num_frames = len(os.listdir(src_seq_filepath)) # ignore frames at the end of the sequence num_frames = int(num_frames/args.len) * args.len for p_frame in range(num_frames): src_path = os.path.join(src_seq_filepath, str(p_frame)+'.jpg') dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.jpg') shutil.copyfile(src_path, dst_path) frame_cnt += 1 if frame_cnt == args.len: subseq_cnt += 1 frame_cnt = 0 mkdir(os.path.join(dst_filepath, str(subseq_cnt))) # test for p_exp in tqdm(range(num_test_exps)): src_seq_filepath = os.path.join(src_test_filepath, str(p_exp)) num_frames = len(os.listdir(src_seq_filepath)) # ignore frames at the end of the sequence num_frames = int(num_frames/args.len) * args.len for p_frame in range(num_frames): src_path = os.path.join(src_seq_filepath, str(p_frame)+'.png') dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.png') shutil.copyfile(src_path, dst_path) frame_cnt += 1 if frame_cnt == args.len: subseq_cnt += 1 frame_cnt = 0 mkdir(os.path.join(dst_filepath, str(subseq_cnt))) # remove the last empty folder shutil.rmtree(os.path.join(dst_filepath, str(subseq_cnt))) ================================================ FILE: datainfo/collect/lava_lamp/split_data.py 
================================================ import os import shutil from tqdm import tqdm import argparse ''' This script splits each sequence, which corresponds to a single experiment, into subsequences with fixed length. ''' def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def get_arguments(): ap = argparse.ArgumentParser() ap.add_argument('-l', '--len', type=int, default=60, help='length of each subsequence') args = ap.parse_args() return args src_train_filepath = "./visphysics_data/lavalamp/raw_data_train" src_test_filepath = "./visphysics_data/lavalamp/raw_data_test" dst_filepath = "./visphysics_data/lavalamp_60" args = get_arguments() num_train_exps = len(os.listdir(src_train_filepath)) num_test_exps = len(os.listdir(src_test_filepath)) subseq_cnt = 0 frame_cnt = 0 mkdir(os.path.join(dst_filepath, str(subseq_cnt))) # train src_seq_filepath = src_train_filepath num_frames = len(os.listdir(src_seq_filepath)) # ignore frames at the end of the sequence num_frames = int(num_frames/args.len) * args.len for p_frame in range(num_frames): src_path = os.path.join(src_seq_filepath, str(p_frame)+'.jpg') dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.jpg') shutil.copyfile(src_path, dst_path) frame_cnt += 1 if frame_cnt == args.len: subseq_cnt += 1 frame_cnt = 0 mkdir(os.path.join(dst_filepath, str(subseq_cnt))) # test src_seq_filepath = src_test_filepath num_frames = len(os.listdir(src_seq_filepath)) # ignore frames at the end of the sequence num_frames = int(num_frames/args.len) * args.len for p_frame in range(num_frames): src_path = os.path.join(src_seq_filepath, str(p_frame)+'.png') dst_path = os.path.join(dst_filepath, str(subseq_cnt), str(frame_cnt)+'.png') shutil.copyfile(src_path, dst_path) frame_cnt += 1 if frame_cnt == args.len: subseq_cnt += 1 frame_cnt = 0 mkdir(os.path.join(dst_filepath, str(subseq_cnt))) # remove the last empty folder shutil.rmtree(os.path.join(dst_filepath, 
str(subseq_cnt))) ================================================ FILE: datainfo/collect/reaction_diffusion/make_data.py ================================================ import os import shutil from tqdm import tqdm import numpy as np import scipy.io as sio from PIL import Image from matplotlib import cm def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def make_data(data_filepath, num_frm): data = sio.loadmat('reaction_diffusion.mat') mkdir(data_filepath) num_seq = int(data['t'].size / num_frm) u_min, u_max = -1, 1 for n in tqdm(range(num_seq)): seq_filepath = os.path.join(data_filepath, str(n)) mkdir(seq_filepath) for k in range(num_frm): u = data['uf'][:, :, n*num_frm+k] u = (u - u_min) / (u_max - u_min) u = cm.viridis(u)[:,:,:3] u = np.uint8(u*255) im = Image.fromarray(u) im.save(os.path.join(seq_filepath, str(k)+'.png')) if __name__ == '__main__': data_filepath = 'reaction_diffusion' make_data(data_filepath, 100) ================================================ FILE: datainfo/collect/reaction_diffusion/reaction_diffusion.m ================================================ clear all; close all; clc % lambda-omega reaction-diffusion system % u_t = lam(A) u - ome(A) v + d1*(u_xx + u_yy) = 0 % v_t = ome(A) u + lam(A) v + d2*(v_xx + v_yy) = 0 % % A^2 = u^2 + v^2 and % lam(A) = 1 - A^2 % ome(A) = -beta*A^2 t=0:0.2:2000.2; d1=0.1; d2=0.1; beta=1.0; L=20; n=128; N=n*n; x2=linspace(-L/2,L/2,n+1); x=x2(1:n); y=x; kx=(2*pi/L)*[0:(n/2-1) -n/2:-1]; ky=kx; % INITIAL CONDITIONS [X,Y]=meshgrid(x,y); [KX,KY]=meshgrid(kx,ky); K2=KX.^2+KY.^2; K22=reshape(K2,N,1); m=1; % number of spirals f = exp(-.01*(X.^2+Y.^2)); % circular mask to get rid of boundaries u_ini=tanh(sqrt(X.^2+Y.^2)).*cos(m*angle(X+i*Y)-(sqrt(X.^2+Y.^2))); v_ini=tanh(sqrt(X.^2+Y.^2)).*sin(m*angle(X+i*Y)-(sqrt(X.^2+Y.^2))); % uf(:,:,1)=f.*u(:,:,1); % vf(:,:,1)=f.*v(:,:,1); % REACTION-DIFFUSION uvt=[reshape(fft2(u_ini),1,N) reshape(fft2(v_ini),1,N)].'; uvt_rhs = 
reaction_diffusion_rhs(t(1),uvt,[],K22,d1,d2,beta,n,N); % du(:,:,1) = real(ifft2(reshape(uvt_rhs(1:N).',n,n))); % dv(:,:,1)= real(ifft2(reshape(uvt_rhs((N+1):(2*N)).',n,n))); [t,uvsol]=ode45('reaction_diffusion_rhs',t,uvt,[],K22,d1,d2,beta,n,N); u = zeros(length(x),length(y),length(t)); v = zeros(length(x),length(y),length(t)); uf = zeros(length(x), length(y), length(t)); vf = zeros(length(x), length(y), length(t)); du = zeros(length(x),length(y),length(t)); dv = zeros(length(x),length(y),length(t)); duf = zeros(length(x),length(y),length(t)); dvf = zeros(length(x),length(y),length(t)); for j=1:length(t)-1 ut=reshape((uvsol(j,1:N).'),n,n); vt=reshape((uvsol(j,(N+1):(2*N)).'),n,n); u(:,:,j+1)=real(ifft2(ut)); v(:,:,j+1)=real(ifft2(vt)); uvt_rhs = reaction_diffusion_rhs(t(j+1),uvsol(j,1:end).',[],K22,d1,d2,beta,n,N); du(:,:,j+1)= real(ifft2(reshape(uvt_rhs(1:N).',n,n))); dv(:,:,j+1)= real(ifft2(reshape(uvt_rhs((N+1):(2*N)).',n,n))); uf(:,:,j+1)=f.*u(:,:,j+1); vf(:,:,j+1)=f.*v(:,:,j+1); duf(:,:,j+1)=f.*du(:,:,j+1); dvf(:,:,j+1)=f.*dv(:,:,j+1); % figure(1) % pcolor(x,y,v(:,:,j+1)); shading interp; colormap(hot); colorbar; drawnow; end t = t(3:end); uf = uf(:,:,3:end); vf = vf(:,:,3:end); duf = duf(:,:,3:end); dvf = dvf(:,:,3:end); save('reaction_diffusion.mat','t','x','y','uf') %% % load reaction_diffusion_big % pcolor(x,y,u(:,:,end)); shading interp; colormap(hot) ================================================ FILE: datainfo/collect/reaction_diffusion/reaction_diffusion_rhs.m ================================================ function rhs=reaction_diffusion_rhs(t,uvt,dummy,K22,d1,d2,beta,n,N); % Calculate u and v terms ut=reshape((uvt(1:N)),n,n); vt=reshape((uvt((N+1):(2*N))),n,n); u=real(ifft2(ut)); v=real(ifft2(vt)); % Reaction Terms u3=u.^3; v3=v.^3; u2v=(u.^2).*v; uv2=u.*(v.^2); utrhs=reshape((fft2(u-u3-uv2+beta*u2v+beta*v3)),N,1); vtrhs=reshape((fft2(v-u2v-v3-beta*u3-beta*uv2)),N,1); rhs=[-d1*K22.*uvt(1:N)+utrhs -d2*K22.*uvt(N+1:end)+vtrhs]; 
================================================ FILE: datainfo/collect/single_pendulum/make_data.py ================================================ import os import shutil import numpy as np from tqdm import tqdm from PIL import Image from scipy.integrate import solve_ivp def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def engine(rng, num_frm, fps=60): # parameters m = 1.0 l = 0.5 g = 9.81 dt = 1.0 / fps t_eval = np.arange(num_frm) * dt # solve equations of motion # y = [theta, vel_theta] f = lambda t, y: [y[1], -3*g/(2*l) * np.sin(y[0])] initial_state = [rng.uniform(0, 2*np.pi), rng.uniform(-10, 10)] sol = solve_ivp(f, [t_eval[0], t_eval[-1]], initial_state, t_eval=t_eval, rtol=1e-6) states = sol.y.T return states def render(theta): bg_color = (215, 205, 192) im = Image.new('RGB', (800, 800), bg_color) pendulum_im = Image.open('./pendulum.png') im.paste(pendulum_im, (363, 400)) im = im.rotate(theta*180/np.pi, fillcolor=bg_color) im = im.resize((128,128)) return im def make_data(data_filepath, num_seq, num_frm, seed=0): mkdir(data_filepath) rng = np.random.default_rng(seed) states = np.zeros((num_seq, num_frm, 2)) for n in tqdm(range(num_seq)): seq_filepath = os.path.join(data_filepath, str(n)) mkdir(seq_filepath) states[n, :, :] = engine(rng, num_frm) for k in range(num_frm): im = render(states[n, k, 0]) im.save(os.path.join(seq_filepath, str(k)+'.png')) np.save(os.path.join(data_filepath, 'states.npy'), states) if __name__ == '__main__': data_filepath = '/data/kuang/visphysics_data/single_pendulum' make_data(data_filepath, num_seq=1200, num_frm=60) ================================================ FILE: datainfo/collect/utils/sort_vids.py ================================================ import os import shutil import numpy as np from tqdm import tqdm def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) data_folder = '/home/cml/Downloads/visphysics_data' object_lst = ['air_dancer', 'fire', 
'swingstick_magnetic', 'swingstick_non_magnetic'] num_train_vids = [22, 1, 18, 68] num_test_vids = [5, 1, 5, 17] new_data_folder= '/home/cml/Downloads/visphysics_data_new' for i in tqdm(range(len(object_lst))): obj = object_lst[i] old_folder = os.path.join(data_folder, obj) new_folder = os.path.join(new_data_folder, obj) # train num_train_frames = len(os.listdir(os.path.join(old_folder, 'raw_data_train'))) ids = np.array(list(range(num_train_frames))) ids = np.split(ids, num_train_vids[i]) for j in tqdm(range(num_train_vids[i])): new_vid_folder = os.path.join(new_folder, 'train', str(j)) mkdir(new_vid_folder) this_ids = ids[j] k = 0 for p_idx in this_ids: src = os.path.join(old_folder, 'raw_data_train', str(p_idx) + '.jpg') dest = os.path.join(new_vid_folder, str(k) + '.jpg') shutil.copy(src, dest) k = k + 1 # test num_test_frames = len(os.listdir(os.path.join(old_folder, 'raw_data_test'))) ids = np.array(list(range(num_test_frames))) ids = np.split(ids, num_test_vids[i]) for j in tqdm(range(num_test_vids[i])): new_vid_folder = os.path.join(new_folder, 'test', str(j)) mkdir(new_vid_folder) this_ids = ids[j] k = 0 for p_idx in this_ids: src = os.path.join(old_folder, 'raw_data_test', str(p_idx) + '.png') dest = os.path.join(new_vid_folder, str(k) + '.png') shutil.copy(src, dest) k = k + 1 ================================================ FILE: datainfo/double_pendulum/data_split_dict_1.json ================================================ { "test": [ 524, 1032, 960, 907, 977, 840, 877, 802, 392, 5, 746, 980, 623, 1068, 675, 1098, 931, 470, 361, 591, 867, 975, 352, 526, 414, 237, 561, 880, 942, 552, 1059, 789, 12, 1005, 464, 345, 348, 806, 631, 89, 961, 60, 1002, 758, 1046, 335, 221, 898, 177, 767, 751, 354, 848, 827, 497, 983, 70, 805, 1034, 1022, 581, 621, 388, 1039, 1072, 1025, 681, 247, 607, 380, 204, 852, 44, 593, 941, 448, 472, 707, 477, 1015, 896, 454, 59, 864, 443, 780, 18, 52, 45, 62, 650, 209, 468, 545, 912, 4, 886, 798, 58, 999, 192, 429, 777, 967, 920, 
1014, 241, 522, 129, 275 ], "val": [ 456, 164, 736, 36, 149, 406, 1075, 230, 21, 1017, 442, 620, 214, 1096, 747, 259, 111, 264, 1089, 815, 431, 351, 395, 319, 24, 116, 485, 508, 329, 719, 465, 301, 728, 663, 279, 672, 172, 540, 1063, 163, 171, 71, 297, 1011, 189, 639, 944, 112, 1004, 255, 287, 773, 772, 14, 463, 17, 888, 85, 72, 689, 861, 33, 261, 836, 871, 816, 564, 817, 93, 881, 185, 598, 563, 181, 1079, 235, 823, 28, 614, 469, 339, 627, 1033, 638, 553, 551, 1, 1040, 424, 365, 832, 496, 423, 516, 963, 1019, 567, 583, 373, 890, 492, 57, 972, 436, 210, 574, 796, 531, 132, 828 ], "train": [ 391, 479, 51, 1077, 882, 625, 268, 197, 433, 635, 246, 685, 640, 803, 131, 398, 1070, 580, 924, 863, 962, 238, 39, 763, 653, 455, 608, 869, 680, 768, 634, 107, 514, 509, 884, 417, 219, 146, 311, 630, 996, 290, 80, 911, 366, 353, 130, 814, 1001, 765, 100, 565, 487, 498, 158, 147, 862, 887, 13, 529, 295, 444, 916, 795, 722, 260, 1016, 950, 609, 284, 895, 542, 535, 381, 450, 294, 750, 502, 494, 326, 556, 67, 1057, 384, 933, 143, 644, 362, 1084, 120, 16, 1094, 706, 788, 233, 905, 1036, 596, 733, 853, 418, 254, 119, 26, 559, 575, 549, 1088, 834, 966, 202, 434, 1095, 95, 308, 507, 590, 154, 988, 357, 1008, 425, 360, 234, 778, 819, 883, 90, 793, 611, 1028, 901, 812, 740, 432, 923, 68, 289, 167, 679, 784, 1093, 186, 671, 891, 849, 316, 729, 669, 213, 375, 801, 304, 991, 569, 9, 633, 603, 428, 190, 940, 642, 368, 771, 857, 282, 263, 419, 330, 1031, 420, 655, 876, 520, 1041, 544, 194, 791, 81, 122, 859, 3, 656, 1042, 201, 956, 753, 266, 262, 782, 243, 334, 54, 99, 341, 140, 280, 1027, 421, 787, 43, 401, 96, 115, 989, 1071, 973, 56, 734, 1061, 617, 850, 30, 928, 534, 573, 1067, 968, 949, 499, 343, 868, 990, 837, 242, 1092, 482, 1083, 992, 568, 350, 723, 969, 612, 467, 344, 810, 503, 1087, 693, 807, 619, 174, 717, 1076, 594, 29, 624, 586, 718, 336, 809, 926, 616, 441, 409, 1045, 1053, 954, 134, 451, 501, 537, 731, 533, 1010, 1006, 491, 708, 206, 688, 125, 927, 113, 770, 459, 709, 453, 390, 
200, 713, 402, 906, 152, 538, 25, 379, 654, 766, 188, 267, 493, 661, 137, 776, 121, 764, 687, 430, 978, 570, 800, 47, 82, 701, 489, 959, 396, 762, 1020, 35, 216, 939, 50, 1007, 1026, 23, 948, 702, 438, 248, 987, 878, 77, 155, 278, 745, 135, 178, 998, 309, 37, 676, 759, 775, 870, 79, 1048, 539, 712, 699, 148, 412, 613, 986, 385, 34, 446, 484, 915, 207, 774, 979, 199, 657, 378, 269, 184, 371, 271, 856, 845, 394, 532, 359, 997, 673, 476, 739, 478, 236, 821, 515, 372, 665, 865, 695, 831, 965, 285, 879, 892, 958, 331, 1000, 666, 1090, 889, 374, 1091, 974, 276, 156, 179, 585, 643, 226, 935, 0, 714, 781, 790, 510, 1049, 257, 223, 447, 293, 651, 530, 825, 69, 332, 370, 1052, 548, 1065, 525, 6, 382, 291, 474, 726, 413, 55, 415, 452, 104, 196, 292, 756, 519, 1078, 215, 647, 716, 913, 779, 437, 1013, 582, 866, 748, 829, 571, 705, 299, 527, 139, 698, 277, 75, 606, 88, 1066, 838, 427, 124, 648, 321, 15, 725, 296, 486, 270, 97, 2, 286, 383, 193, 127, 808, 151, 338, 854, 483, 1080, 1050, 307, 677, 61, 108, 636, 700, 318, 684, 546, 166, 909, 84, 696, 1035, 462, 602, 481, 195, 760, 860, 123, 473, 435, 227, 517, 358, 900, 337, 495, 715, 315, 397, 785, 142, 936, 737, 622, 769, 660, 1074, 168, 393, 356, 403, 1018, 53, 833, 855, 659, 159, 180, 386, 641, 126, 87, 475, 169, 1058, 599, 844, 902, 208, 543, 918, 377, 523, 73, 595, 224, 281, 1024, 422, 1085, 490, 506, 724, 244, 231, 1064, 512, 953, 541, 952, 7, 994, 239, 662, 176, 410, 407, 730, 914, 572, 749, 141, 363, 63, 19, 841, 903, 1043, 1069, 830, 1038, 405, 11, 1023, 597, 449, 670, 678, 562, 440, 457, 683, 839, 253, 49, 161, 550, 946, 306, 182, 820, 566, 323, 32, 558, 145, 211, 1097, 300, 1012, 109, 312, 938, 144, 153, 183, 1037, 826, 743, 652, 513, 943, 157, 674, 505, 367, 921, 692, 22, 76, 847, 757, 930, 752, 804, 249, 20, 872, 250, 94, 792, 649, 1029, 874, 103, 342, 251, 310, 592, 682, 191, 799, 42, 668, 811, 232, 985, 658, 588, 92, 458, 91, 929, 738, 369, 252, 203, 314, 710, 554, 187, 265, 364, 480, 704, 632, 220, 256, 114, 466, 
615, 324, 65, 64, 408, 320, 400, 460, 327, 610, 727, 10, 27, 908, 325, 667, 212, 102, 322, 488, 894, 741, 910, 555, 744, 445, 105, 842, 1056, 697, 919, 1009, 118, 165, 899, 600, 245, 897, 754, 843, 971, 947, 932, 686, 628, 1021, 735, 46, 110, 283, 1081, 893, 786, 577, 302, 993, 273, 83, 579, 229, 951, 584, 846, 824, 601, 629, 117, 349, 851, 150, 389, 74, 416, 40, 858, 813, 945, 1062, 797, 500, 732, 618, 783, 298, 904, 1054, 346, 376, 917, 995, 957, 340, 274, 794, 964, 170, 173, 136, 86, 41, 742, 66, 240, 1082, 970, 547, 703, 922, 560, 1051, 98, 818, 272, 218, 439, 347, 138, 576, 1047, 955, 160, 885, 288, 411, 626, 333, 934, 511, 981, 303, 399, 1055, 106, 504, 198, 605, 1073, 690, 587, 1044, 101, 355, 205, 387, 835, 521, 637, 720, 1086, 175, 471, 976, 222, 604, 38, 984, 8, 133, 258, 578, 426, 162, 761, 873, 317, 78, 937, 313, 48, 217, 128, 305, 755, 1030, 875, 646, 1003, 328, 822, 589, 691, 404, 31, 664, 536, 228, 461, 528, 711, 925, 645, 225, 1060, 557, 982, 694, 518, 721 ] } ================================================ FILE: datainfo/double_pendulum/data_split_dict_2.json ================================================ { "test": [ 24, 688, 255, 176, 368, 371, 58, 32, 777, 734, 919, 433, 61, 901, 966, 215, 844, 250, 272, 874, 139, 534, 772, 108, 937, 896, 698, 232, 605, 1066, 50, 668, 588, 60, 217, 391, 17, 699, 154, 750, 1001, 425, 638, 832, 1032, 621, 633, 982, 1082, 340, 664, 454, 996, 935, 718, 1062, 931, 1057, 1025, 1020, 571, 1003, 511, 944, 818, 330, 1060, 741, 724, 1079, 849, 912, 372, 1052, 736, 1065, 1044, 279, 355, 665, 361, 48, 472, 483, 363, 336, 867, 778, 652, 952, 745, 56, 1090, 549, 1028, 911, 761, 1042, 805, 882, 324, 73, 434, 515, 631, 346, 739, 173, 187, 115 ], "val": [ 810, 639, 233, 1017, 82, 256, 75, 455, 731, 238, 252, 1041, 725, 520, 237, 650, 464, 98, 174, 165, 134, 34, 259, 995, 686, 143, 1049, 716, 18, 1048, 429, 620, 268, 265, 349, 924, 147, 335, 527, 498, 402, 799, 598, 530, 986, 895, 807, 458, 989, 104, 858, 323, 703, 676, 95, 
230, 484, 157, 722, 636, 883, 411, 773, 270, 923, 46, 757, 619, 784, 564, 459, 315, 31, 500, 345, 292, 1098, 765, 760, 642, 630, 352, 4, 37, 155, 253, 813, 44, 603, 394, 1, 708, 535, 188, 752, 160, 958, 1033, 130, 261, 382, 21, 940, 746, 41, 25, 69, 977, 117, 84 ], "train": [ 521, 97, 212, 436, 465, 835, 410, 195, 591, 288, 988, 828, 450, 127, 42, 93, 92, 970, 873, 239, 424, 1086, 956, 310, 266, 824, 142, 717, 119, 727, 1008, 357, 871, 584, 64, 396, 817, 406, 57, 2, 681, 9, 249, 864, 512, 198, 245, 40, 473, 1089, 299, 182, 601, 379, 321, 955, 763, 486, 766, 22, 906, 744, 171, 316, 702, 467, 370, 898, 274, 820, 1015, 644, 443, 707, 13, 576, 764, 109, 838, 559, 158, 26, 670, 721, 329, 789, 798, 802, 767, 133, 431, 539, 325, 672, 485, 618, 140, 135, 438, 692, 408, 608, 997, 640, 29, 344, 572, 421, 748, 6, 327, 1039, 235, 308, 796, 635, 290, 581, 713, 178, 612, 943, 568, 866, 207, 546, 519, 917, 889, 214, 339, 111, 448, 35, 63, 248, 282, 405, 163, 629, 782, 226, 659, 850, 902, 487, 1058, 592, 540, 596, 505, 831, 801, 575, 1046, 800, 833, 963, 556, 758, 504, 347, 236, 206, 441, 682, 1085, 547, 1040, 507, 87, 1094, 580, 975, 604, 648, 569, 294, 7, 916, 607, 788, 228, 281, 573, 376, 219, 915, 241, 167, 208, 1080, 216, 529, 1075, 94, 690, 907, 542, 578, 146, 120, 968, 49, 862, 211, 680, 447, 815, 586, 1068, 488, 693, 306, 271, 269, 314, 43, 792, 797, 976, 417, 954, 191, 499, 151, 729, 854, 848, 552, 781, 307, 407, 826, 881, 985, 1097, 1069, 397, 423, 508, 168, 76, 998, 627, 439, 860, 258, 229, 105, 645, 590, 751, 377, 759, 791, 597, 80, 106, 887, 891, 386, 753, 1054, 322, 100, 809, 965, 300, 843, 661, 913, 899, 617, 283, 855, 567, 99, 842, 926, 737, 574, 634, 967, 440, 987, 948, 637, 960, 886, 416, 562, 51, 513, 1019, 662, 732, 803, 66, 356, 460, 419, 897, 152, 932, 193, 359, 945, 606, 476, 583, 660, 624, 786, 369, 466, 506, 110, 557, 403, 1012, 351, 378, 1091, 981, 276, 415, 0, 251, 427, 953, 432, 942, 770, 683, 616, 566, 877, 787, 677, 456, 257, 70, 15, 286, 71, 929, 33, 
123, 689, 756, 518, 332, 827, 47, 780, 614, 790, 12, 876, 112, 632, 541, 671, 129, 845, 694, 234, 156, 201, 482, 490, 884, 348, 1022, 1021, 302, 380, 921, 738, 231, 1084, 197, 210, 723, 1016, 354, 205, 964, 1038, 517, 755, 593, 1064, 305, 918, 278, 691, 1005, 1006, 242, 675, 200, 27, 116, 492, 430, 1073, 1030, 949, 687, 59, 295, 267, 920, 613, 1009, 398, 227, 880, 890, 90, 1081, 951, 785, 865, 331, 190, 172, 23, 312, 1061, 122, 362, 126, 646, 175, 45, 442, 678, 563, 446, 128, 169, 857, 775, 719, 1070, 991, 806, 623, 754, 77, 615, 684, 145, 463, 199, 304, 714, 666, 816, 863, 277, 868, 223, 333, 841, 878, 532, 808, 469, 39, 183, 85, 179, 53, 14, 65, 861, 704, 565, 1010, 457, 973, 244, 225, 647, 461, 132, 859, 247, 202, 936, 972, 962, 343, 544, 318, 1007, 149, 747, 959, 611, 385, 930, 479, 360, 445, 365, 384, 1092, 328, 287, 350, 409, 387, 743, 555, 1083, 1011, 728, 189, 776, 194, 480, 243, 101, 162, 334, 1018, 643, 649, 137, 847, 298, 452, 749, 170, 994, 554, 730, 79, 1043, 326, 218, 373, 388, 1037, 1045, 62, 284, 971, 705, 836, 888, 983, 622, 91, 516, 153, 543, 957, 837, 892, 933, 922, 673, 319, 341, 311, 651, 264, 783, 309, 551, 742, 136, 510, 879, 946, 453, 577, 654, 550, 900, 814, 969, 641, 570, 289, 10, 1095, 301, 653, 390, 1051, 166, 55, 138, 978, 908, 1034, 1067, 81, 28, 471, 514, 273, 470, 762, 496, 829, 67, 658, 795, 3, 358, 74, 209, 478, 177, 579, 525, 667, 493, 834, 8, 141, 626, 501, 779, 594, 600, 669, 204, 905, 812, 856, 220, 395, 910, 582, 221, 655, 1047, 823, 181, 822, 1026, 846, 320, 11, 894, 1024, 412, 192, 938, 545, 934, 83, 338, 560, 875, 1072, 710, 925, 1002, 414, 280, 494, 54, 522, 999, 184, 392, 609, 30, 451, 477, 148, 587, 367, 413, 240, 1071, 161, 1096, 263, 526, 254, 528, 1074, 720, 872, 553, 159, 114, 293, 536, 992, 449, 337, 711, 903, 774, 399, 735, 131, 38, 663, 697, 88, 68, 426, 150, 558, 180, 364, 72, 1014, 589, 599, 86, 715, 695, 712, 437, 561, 1055, 610, 1078, 974, 852, 769, 303, 696, 503, 909, 1063, 489, 870, 495, 275, 383, 144, 468, 
366, 297, 509, 1013, 947, 1050, 481, 674, 1035, 1059, 1087, 709, 840, 400, 656, 342, 1093, 939, 125, 853, 260, 124, 404, 353, 771, 904, 794, 733, 401, 979, 585, 118, 196, 502, 941, 224, 78, 740, 804, 811, 927, 291, 5, 885, 628, 602, 213, 474, 497, 1077, 420, 1029, 462, 16, 990, 793, 203, 313, 851, 103, 422, 839, 1053, 950, 381, 706, 296, 375, 625, 121, 531, 19, 374, 491, 821, 679, 96, 185, 830, 537, 428, 52, 1088, 595, 980, 523, 435, 444, 825, 1036, 1004, 1076, 701, 1023, 389, 657, 548, 317, 914, 475, 685, 533, 984, 222, 107, 893, 869, 186, 20, 102, 928, 246, 89, 1056, 524, 113, 164, 418, 393, 36, 1027, 961, 768, 538, 285, 1000, 700, 262, 993, 726, 1031, 819 ] } ================================================ FILE: datainfo/double_pendulum/data_split_dict_3.json ================================================ { "test": [ 887, 659, 395, 532, 890, 471, 385, 386, 882, 918, 141, 981, 368, 321, 347, 888, 1003, 43, 706, 159, 269, 625, 298, 417, 994, 202, 971, 32, 548, 614, 110, 78, 7, 317, 1065, 483, 571, 677, 773, 92, 90, 243, 850, 1082, 601, 41, 1085, 840, 136, 704, 181, 990, 987, 129, 254, 583, 546, 432, 213, 668, 334, 572, 58, 689, 475, 834, 718, 790, 1038, 862, 1076, 893, 528, 444, 1013, 278, 73, 199, 748, 274, 910, 808, 874, 793, 968, 551, 63, 616, 87, 326, 131, 31, 798, 1071, 310, 474, 308, 813, 975, 963, 392, 479, 531, 960, 26, 134, 970, 757, 267, 487 ], "val": [ 223, 897, 759, 344, 81, 173, 876, 829, 448, 229, 853, 691, 341, 329, 930, 104, 99, 714, 665, 445, 694, 191, 335, 244, 275, 823, 669, 227, 512, 1022, 1094, 752, 700, 582, 27, 832, 1056, 107, 996, 806, 307, 270, 988, 864, 936, 776, 320, 189, 372, 1039, 327, 615, 606, 305, 467, 643, 257, 378, 21, 692, 62, 974, 22, 501, 755, 285, 723, 623, 950, 939, 695, 361, 477, 340, 642, 648, 61, 1037, 647, 603, 630, 995, 20, 322, 593, 425, 807, 11, 1005, 561, 1083, 533, 264, 447, 1035, 958, 1030, 732, 737, 649, 441, 277, 519, 830, 1088, 635, 105, 1050, 697, 609 ], "train": [ 401, 148, 464, 711, 957, 1084, 420, 1007, 
318, 1074, 579, 972, 434, 2, 770, 220, 1010, 540, 192, 740, 857, 788, 545, 408, 352, 163, 469, 118, 758, 534, 506, 587, 504, 396, 863, 760, 231, 794, 1059, 1049, 768, 655, 456, 279, 313, 459, 769, 371, 80, 747, 66, 67, 525, 804, 521, 144, 1081, 894, 1060, 1066, 682, 439, 476, 142, 576, 620, 38, 672, 909, 345, 328, 1019, 193, 1087, 713, 179, 413, 156, 607, 640, 898, 1046, 511, 1027, 253, 558, 135, 781, 1098, 599, 945, 1073, 157, 905, 221, 197, 177, 59, 219, 154, 178, 486, 421, 727, 64, 53, 402, 1097, 355, 751, 228, 646, 631, 852, 452, 627, 1052, 309, 292, 462, 325, 784, 555, 83, 225, 149, 350, 1025, 1031, 468, 427, 969, 49, 846, 836, 262, 473, 1012, 520, 12, 573, 60, 212, 1048, 973, 820, 671, 374, 460, 249, 859, 171, 172, 198, 218, 407, 746, 908, 502, 112, 1036, 450, 594, 1089, 999, 931, 809, 301, 137, 236, 819, 639, 753, 843, 1002, 1069, 75, 207, 100, 214, 300, 346, 720, 8, 578, 89, 188, 272, 1070, 418, 84, 725, 167, 238, 875, 800, 470, 390, 1055, 835, 211, 312, 217, 656, 0, 405, 151, 941, 1026, 174, 375, 586, 304, 814, 315, 265, 855, 1072, 929, 143, 565, 778, 485, 779, 337, 201, 562, 399, 186, 709, 500, 608, 1080, 362, 29, 743, 621, 42, 48, 121, 949, 1040, 696, 266, 557, 685, 6, 690, 624, 10, 1044, 955, 952, 719, 424, 16, 680, 162, 415, 287, 575, 821, 411, 363, 780, 503, 754, 71, 14, 86, 935, 702, 394, 299, 717, 721, 921, 589, 977, 693, 658, 889, 916, 496, 114, 296, 185, 938, 842, 712, 792, 772, 19, 233, 454, 224, 1018, 896, 789, 966, 155, 698, 106, 216, 687, 872, 242, 449, 404, 82, 535, 480, 209, 998, 482, 629, 248, 113, 1041, 705, 1016, 673, 775, 40, 17, 370, 1034, 68, 841, 168, 165, 507, 9, 801, 762, 39, 259, 364, 316, 633, 803, 25, 319, 666, 239, 708, 914, 728, 145, 1057, 412, 815, 472, 518, 699, 108, 619, 397, 735, 564, 652, 684, 745, 251, 119, 338, 951, 943, 379, 97, 797, 387, 559, 1061, 466, 458, 351, 844, 1032, 324, 158, 679, 282, 161, 443, 208, 828, 15, 899, 605, 526, 103, 670, 653, 1045, 543, 451, 517, 597, 724, 976, 303, 65, 516, 1000, 574, 343, 1047, 
290, 96, 357, 674, 416, 654, 667, 288, 617, 1058, 1095, 513, 437, 77, 57, 764, 498, 457, 736, 414, 822, 993, 206, 170, 928, 263, 393, 356, 681, 831, 628, 774, 937, 816, 913, 196, 138, 956, 750, 660, 200, 923, 294, 1064, 283, 510, 1028, 366, 707, 810, 947, 650, 595, 610, 246, 234, 95, 331, 153, 284, 854, 932, 638, 1001, 783, 169, 481, 72, 286, 55, 1075, 817, 901, 911, 455, 926, 91, 76, 741, 373, 489, 756, 777, 380, 505, 927, 967, 596, 367, 536, 180, 560, 252, 891, 1051, 1009, 585, 289, 538, 552, 946, 1024, 917, 644, 730, 591, 877, 377, 686, 895, 256, 989, 115, 442, 1020, 281, 924, 676, 786, 400, 742, 215, 984, 497, 388, 446, 339, 342, 30, 824, 997, 463, 365, 839, 851, 920, 664, 120, 866, 93, 881, 922, 925, 641, 884, 261, 845, 900, 140, 1015, 194, 867, 260, 865, 749, 715, 235, 478, 811, 767, 74, 802, 688, 570, 1090, 210, 491, 622, 147, 1067, 892, 645, 731, 176, 205, 391, 490, 1068, 23, 314, 222, 598, 791, 302, 550, 237, 703, 744, 150, 965, 273, 139, 190, 733, 453, 612, 611, 509, 98, 166, 152, 410, 837, 495, 529, 983, 878, 964, 164, 422, 240, 907, 661, 85, 701, 569, 358, 406, 613, 376, 785, 403, 959, 465, 255, 636, 247, 79, 184, 954, 332, 563, 782, 566, 102, 567, 109, 796, 541, 523, 663, 37, 1029, 306, 592, 291, 493, 117, 1093, 683, 787, 24, 183, 384, 101, 675, 94, 761, 45, 886, 241, 146, 1096, 5, 1033, 985, 544, 577, 1078, 111, 1092, 871, 795, 726, 860, 124, 268, 982, 383, 333, 28, 133, 436, 1004, 435, 738, 604, 276, 1017, 879, 849, 18, 716, 428, 430, 739, 508, 953, 381, 915, 979, 833, 44, 651, 722, 554, 880, 618, 1063, 382, 1006, 182, 580, 883, 389, 3, 568, 130, 126, 438, 336, 870, 271, 369, 311, 600, 398, 662, 991, 359, 1023, 869, 160, 323, 942, 514, 527, 88, 729, 827, 494, 70, 766, 127, 47, 1042, 56, 1, 330, 484, 52, 433, 992, 539, 632, 903, 1043, 848, 51, 409, 584, 934, 499, 1021, 1079, 280, 245, 799, 1086, 547, 125, 1014, 226, 360, 948, 1054, 553, 258, 128, 818, 116, 626, 54, 838, 549, 537, 961, 904, 902, 919, 515, 175, 873, 861, 492, 13, 50, 590, 440, 203, 524, 
710, 35, 46, 250, 122, 293, 602, 69, 232, 349, 556, 295, 1091, 765, 678, 771, 933, 763, 33, 906, 734, 986, 1062, 522, 637, 488, 4, 204, 1008, 423, 36, 419, 581, 429, 297, 426, 978, 354, 1053, 885, 812, 530, 1011, 431, 132, 940, 353, 634, 825, 1077, 657, 847, 868, 348, 944, 187, 588, 858, 856, 826, 962, 195, 542, 34, 123, 805, 230, 980, 461, 912 ] } ================================================ FILE: datainfo/elastic_pendulum/data_split_dict_1.json ================================================ { "test": [ 187, 370, 362, 470, 57, 938, 678, 3, 708, 1137, 730, 993, 846, 1033, 409, 746, 985, 114, 872, 420, 1062, 264, 1049, 785, 11, 551, 940, 723, 704, 1052, 828, 475, 1105, 408, 25, 464, 1028, 345, 348, 806, 631, 89, 961, 60, 1002, 758, 1142, 1066, 335, 221, 1041, 898, 177, 767, 1123, 751, 354, 848, 827, 497, 983, 70, 805, 1034, 1022, 581, 621, 388, 1039, 1171, 1025, 681, 247, 607, 380, 204, 1139, 852, 44, 593, 941, 448, 472, 707, 477, 1132, 1015, 896, 454, 1080, 59, 864, 443, 780, 18, 1108, 52, 45, 62, 650, 209, 468, 545, 912, 4, 886, 798, 58, 999, 192, 429, 777, 967, 920, 1014, 241, 522, 129, 1165, 275 ], "val": [ 411, 892, 626, 333, 17, 1071, 516, 303, 399, 1151, 960, 106, 504, 198, 605, 1172, 918, 690, 587, 210, 101, 355, 205, 387, 1000, 521, 637, 720, 1186, 890, 888, 847, 175, 471, 583, 922, 1096, 1046, 839, 604, 38, 870, 899, 574, 8, 133, 258, 578, 426, 162, 761, 305, 1122, 939, 317, 78, 879, 1036, 313, 1053, 1167, 217, 991, 257, 611, 120, 1093, 657, 808, 1178, 457, 923, 451, 873, 1183, 328, 72, 299, 813, 36, 461, 42, 884, 428, 1044, 519, 222, 529, 385, 862, 703, 791, 638, 48, 233, 970, 1016, 659, 931, 603, 558, 344, 1079, 326, 342, 142, 594, 705, 378, 224, 550, 511, 575, 29, 927, 34, 170, 144, 66, 1196 ], "train": [ 483, 489, 341, 685, 96, 1, 1148, 542, 656, 744, 1085, 613, 855, 775, 680, 111, 1081, 648, 308, 733, 1023, 189, 759, 61, 482, 639, 1190, 228, 776, 596, 885, 1024, 1187, 1048, 295, 1158, 1146, 1094, 819, 684, 540, 793, 856, 1099, 622, 262, 773, 
609, 107, 263, 260, 1160, 958, 474, 672, 739, 442, 478, 887, 1083, 734, 612, 1084, 956, 843, 55, 1067, 750, 823, 9, 532, 112, 282, 329, 520, 874, 359, 1125, 1102, 514, 668, 1164, 459, 197, 844, 753, 635, 633, 185, 1195, 966, 201, 906, 1059, 569, 146, 580, 39, 1047, 446, 597, 276, 43, 1042, 533, 561, 166, 1003, 154, 778, 647, 243, 67, 768, 617, 1149, 634, 981, 649, 450, 507, 215, 1057, 740, 691, 710, 1120, 858, 889, 682, 71, 531, 869, 415, 849, 444, 237, 716, 214, 764, 735, 765, 595, 486, 90, 834, 989, 437, 292, 1127, 549, 330, 16, 304, 559, 1168, 104, 186, 916, 719, 99, 623, 476, 293, 498, 286, 1136, 905, 503, 143, 952, 917, 167, 242, 26, 797, 149, 315, 652, 897, 871, 234, 271, 223, 1029, 769, 238, 673, 924, 294, 1199, 567, 33, 207, 821, 945, 795, 895, 548, 784, 729, 987, 412, 908, 986, 965, 375, 427, 6, 891, 456, 195, 980, 641, 435, 419, 1004, 925, 721, 538, 80, 718, 762, 679, 1063, 131, 937, 311, 307, 552, 485, 1054, 1043, 959, 1176, 840, 1037, 432, 699, 1177, 812, 424, 1157, 752, 360, 190, 68, 316, 693, 1194, 383, 384, 951, 517, 741, 1110, 1072, 671, 676, 219, 179, 1121, 30, 396, 1106, 1005, 585, 402, 321, 954, 943, 606, 749, 255, 859, 530, 865, 390, 395, 976, 877, 556, 268, 1007, 755, 957, 909, 702, 933, 509, 929, 130, 582, 289, 1073, 277, 756, 973, 1100, 394, 525, 81, 598, 100, 147, 379, 196, 644, 851, 2, 782, 351, 824, 1090, 56, 995, 586, 636, 919, 660, 108, 199, 1045, 280, 1091, 1107, 553, 757, 421, 193, 655, 653, 119, 494, 492, 1019, 675, 343, 1134, 140, 339, 692, 453, 1018, 438, 1163, 964, 571, 695, 640, 28, 696, 95, 51, 510, 164, 493, 910, 54, 1188, 156, 911, 998, 1170, 123, 838, 158, 184, 760, 401, 371, 1141, 974, 269, 235, 134, 770, 75, 372, 747, 227, 296, 1032, 374, 1173, 487, 97, 434, 127, 893, 1129, 841, 206, 591, 15, 338, 125, 361, 619, 694, 113, 366, 599, 413, 469, 404, 842, 88, 1155, 787, 200, 1092, 1068, 738, 714, 463, 630, 543, 913, 334, 152, 481, 1114, 645, 285, 462, 854, 188, 337, 267, 425, 357, 984, 717, 137, 1031, 121, 947, 270, 1131, 473, 
484, 1086, 1013, 836, 570, 47, 353, 82, 1088, 832, 84, 501, 226, 336, 1075, 666, 35, 216, 781, 508, 50, 465, 491, 23, 763, 669, 926, 674, 132, 687, 236, 77, 155, 278, 350, 1058, 715, 1082, 534, 135, 467, 178, 5, 309, 397, 417, 37, 754, 670, 433, 902, 815, 524, 479, 544, 664, 79, 701, 789, 528, 748, 829, 148, 332, 495, 904, 539, 381, 745, 790, 69, 1061, 414, 624, 398, 727, 368, 202, 537, 174, 535, 590, 1162, 1051, 643, 978, 1087, 642, 602, 915, 115, 972, 779, 1038, 731, 358, 816, 452, 850, 0, 331, 1145, 833, 515, 447, 568, 382, 663, 1135, 139, 1192, 13, 589, 689, 867, 290, 1180, 392, 213, 124, 430, 502, 365, 722, 172, 935, 651, 318, 279, 955, 950, 151, 291, 736, 724, 266, 248, 31, 21, 883, 665, 194, 1184, 572, 837, 254, 284, 820, 1150, 418, 546, 441, 122, 1021, 499, 725, 662, 963, 809, 406, 712, 391, 246, 455, 1124, 1065, 85, 875, 928, 914, 783, 536, 1074, 766, 168, 527, 393, 356, 403, 1027, 53, 994, 997, 436, 700, 159, 180, 386, 860, 831, 620, 126, 87, 608, 1111, 169, 1154, 565, 713, 1011, 319, 208, 646, 163, 377, 523, 73, 709, 1069, 281, 625, 573, 1117, 422, 230, 490, 506, 814, 244, 231, 654, 1161, 512, 772, 541, 181, 7, 944, 239, 932, 627, 176, 410, 1098, 407, 868, 802, 737, 907, 1153, 141, 901, 363, 63, 19, 880, 1008, 661, 24, 1156, 1101, 992, 1143, 405, 1104, 1115, 711, 449, 794, 1119, 562, 440, 1030, 698, 811, 1006, 253, 683, 49, 161, 1070, 975, 306, 182, 979, 706, 566, 688, 1109, 323, 32, 1060, 145, 211, 1197, 300, 616, 526, 726, 109, 312, 817, 1077, 153, 183, 1198, 969, 996, 881, 774, 513, 1133, 157, 799, 505, 367, 297, 822, 1179, 22, 76, 1095, 1017, 900, 1193, 894, 1055, 249, 20, 225, 250, 94, 818, 1103, 962, 557, 103, 1064, 251, 310, 592, 810, 191, 953, 1130, 792, 968, 232, 948, 658, 588, 826, 92, 458, 771, 91, 287, 876, 369, 252, 203, 314, 1138, 554, 1169, 265, 364, 677, 480, 1175, 835, 796, 632, 803, 220, 256, 1097, 466, 615, 324, 65, 64, 1113, 320, 400, 460, 327, 610, 743, 863, 866, 10, 27, 853, 325, 667, 212, 102, 322, 488, 728, 259, 1056, 301, 555, 
825, 882, 445, 105, 1010, 1009, 1152, 697, 171, 1040, 118, 165, 431, 600, 804, 245, 1189, 1020, 1116, 845, 1001, 423, 93, 14, 686, 628, 12, 1026, 1144, 46, 1126, 110, 283, 1181, 1012, 934, 577, 302, 373, 273, 83, 579, 229, 563, 584, 1166, 1140, 800, 601, 629, 117, 349, 128, 1089, 150, 807, 1185, 389, 74, 416, 40, 1035, 971, 788, 564, 1159, 949, 500, 732, 988, 618, 990, 930, 298, 116, 1118, 346, 376, 261, 861, 518, 614, 340, 1191, 274, 946, 1112, 1076, 173, 136, 86, 41, 742, 1078, 240, 1182, 786, 496, 547, 1050, 942, 903, 936, 352, 560, 1147, 857, 98, 977, 272, 218, 439, 347, 138, 801, 576, 830, 1128, 878, 982, 160, 1174, 288, 921 ] } ================================================ FILE: datainfo/elastic_pendulum/data_split_dict_2.json ================================================ { "test": [ 789, 3, 1071, 376, 321, 261, 523, 764, 43, 83, 51, 138, 235, 169, 1167, 510, 352, 737, 742, 116, 65, 866, 123, 431, 501, 544, 1163, 1069, 1113, 464, 559, 100, 120, 217, 391, 17, 699, 154, 750, 1048, 1001, 425, 638, 832, 1039, 1060, 1032, 621, 633, 982, 1181, 340, 664, 454, 996, 935, 718, 1171, 931, 1152, 1055, 1025, 1020, 571, 1003, 511, 1086, 944, 818, 330, 1156, 741, 724, 1178, 1075, 849, 912, 372, 1146, 1052, 736, 1162, 1044, 279, 355, 665, 361, 48, 472, 483, 363, 1147, 336, 1076, 867, 778, 652, 952, 745, 56, 1191, 549, 1028, 911, 1114, 761, 1042, 805, 882, 324, 1190, 73, 434, 515, 631, 346, 739, 173, 187, 115 ], "val": [ 954, 706, 296, 375, 625, 121, 943, 531, 19, 374, 491, 821, 679, 96, 942, 185, 983, 537, 951, 428, 902, 52, 605, 595, 21, 1158, 435, 444, 825, 746, 215, 986, 1175, 1037, 701, 1108, 389, 657, 548, 317, 1109, 475, 685, 533, 25, 222, 107, 237, 768, 186, 20, 102, 104, 246, 89, 1151, 524, 113, 1029, 418, 393, 1046, 1117, 9, 570, 1101, 525, 1160, 467, 164, 513, 150, 910, 476, 505, 64, 474, 928, 196, 349, 1149, 269, 68, 518, 1099, 287, 36, 859, 536, 530, 698, 294, 671, 997, 804, 1085, 917, 49, 208, 647, 191, 461, 968, 314, 822, 540, 93, 918, 1194, 63, 1000, 
690, 585, 231, 704, 8, 74, 310, 507, 88 ], "train": [ 689, 907, 403, 285, 1073, 796, 316, 140, 1183, 132, 861, 1047, 167, 556, 985, 999, 1185, 891, 984, 199, 717, 40, 937, 354, 405, 271, 1155, 1125, 262, 900, 292, 304, 151, 909, 397, 580, 820, 646, 667, 1057, 183, 684, 760, 450, 1159, 726, 198, 747, 880, 1012, 335, 719, 39, 913, 84, 168, 207, 1051, 396, 618, 1088, 834, 567, 146, 923, 814, 921, 629, 299, 1139, 45, 734, 216, 758, 302, 424, 128, 60, 406, 85, 447, 205, 710, 547, 566, 32, 156, 18, 892, 105, 1170, 239, 659, 1061, 797, 76, 248, 541, 641, 783, 969, 1199, 508, 94, 58, 394, 639, 863, 479, 819, 1034, 487, 947, 862, 283, 305, 327, 438, 744, 975, 386, 723, 202, 756, 979, 236, 1092, 1094, 774, 343, 219, 920, 360, 945, 601, 810, 998, 568, 932, 197, 603, 830, 527, 808, 940, 46, 792, 875, 816, 1005, 506, 916, 634, 976, 244, 1009, 2, 1198, 290, 766, 1182, 840, 417, 182, 249, 210, 443, 623, 307, 590, 1043, 592, 925, 155, 858, 675, 90, 427, 442, 848, 846, 1093, 692, 1030, 1131, 35, 255, 111, 1166, 469, 1016, 22, 614, 377, 1018, 26, 195, 499, 165, 37, 333, 883, 899, 247, 906, 981, 288, 365, 616, 693, 1070, 7, 1135, 441, 1179, 874, 188, 992, 851, 347, 1054, 456, 258, 991, 529, 1110, 1066, 970, 69, 370, 457, 735, 1137, 546, 1050, 13, 879, 312, 348, 872, 23, 226, 611, 1130, 678, 445, 57, 557, 569, 853, 668, 532, 272, 806, 379, 977, 974, 223, 266, 385, 127, 971, 31, 562, 1112, 1058, 624, 753, 798, 1161, 300, 281, 980, 552, 253, 42, 142, 228, 486, 573, 436, 1103, 407, 459, 781, 1023, 1142, 780, 1077, 225, 175, 517, 714, 163, 29, 1157, 622, 14, 6, 961, 473, 274, 660, 157, 329, 295, 1017, 241, 617, 1065, 1021, 158, 731, 828, 811, 1064, 227, 378, 888, 565, 1063, 586, 1116, 411, 1098, 707, 232, 229, 87, 1100, 423, 833, 1041, 838, 1134, 612, 133, 896, 214, 109, 578, 946, 887, 645, 398, 1121, 539, 757, 670, 212, 211, 687, 572, 865, 1196, 700, 80, 584, 149, 588, 662, 351, 1136, 972, 542, 708, 325, 521, 455, 1111, 1096, 957, 432, 306, 786, 655, 1033, 458, 591, 850, 1193, 134, 380, 
1174, 1148, 1144, 1040, 1010, 962, 200, 99, 683, 876, 384, 282, 308, 482, 1192, 1123, 857, 357, 894, 583, 669, 117, 615, 250, 785, 466, 564, 139, 575, 1090, 607, 50, 415, 1053, 1097, 868, 950, 66, 498, 1180, 1118, 174, 152, 953, 193, 967, 898, 1084, 77, 1019, 847, 53, 705, 776, 345, 649, 1056, 268, 581, 201, 490, 331, 318, 635, 619, 672, 519, 598, 27, 632, 108, 492, 534, 126, 59, 276, 812, 0, 251, 855, 484, 122, 448, 270, 1124, 596, 267, 416, 941, 835, 949, 964, 749, 930, 800, 277, 864, 356, 70, 15, 286, 71, 1013, 33, 1102, 1127, 460, 439, 897, 613, 332, 419, 172, 784, 504, 703, 933, 12, 844, 112, 606, 772, 948, 129, 4, 593, 408, 234, 993, 666, 637, 179, 512, 965, 839, 966, 110, 676, 463, 751, 421, 815, 465, 410, 1078, 794, 92, 1189, 1015, 485, 1067, 955, 233, 597, 1195, 119, 1059, 576, 135, 278, 1133, 440, 1035, 759, 147, 1107, 929, 344, 47, 245, 252, 914, 608, 654, 256, 369, 430, 885, 257, 339, 769, 145, 97, 680, 206, 323, 807, 1022, 752, 1138, 554, 787, 242, 604, 555, 738, 446, 721, 809, 171, 359, 106, 711, 130, 642, 322, 190, 1014, 488, 802, 362, 673, 178, 677, 904, 653, 1083, 788, 1105, 630, 799, 328, 775, 713, 1045, 350, 409, 387, 563, 1188, 651, 702, 915, 1122, 189, 919, 194, 480, 243, 101, 162, 334, 1106, 755, 905, 137, 813, 298, 452, 889, 170, 371, 767, 716, 79, 852, 326, 218, 373, 388, 644, 827, 62, 284, 535, 836, 574, 886, 238, 41, 681, 91, 643, 516, 153, 543, 1141, 901, 990, 648, 520, 95, 1049, 893, 658, 319, 661, 341, 311, 765, 264, 926, 309, 551, 881, 136, 1132, 1095, 958, 762, 453, 577, 1008, 550, 34, 963, 1, 729, 1024, 289, 10, 826, 301, 1074, 1087, 390, 1145, 166, 55, 1091, 640, 691, 1140, 790, 779, 1164, 934, 81, 28, 471, 514, 273, 470, 903, 1082, 496, 1129, 67, 777, 1002, 824, 939, 1081, 358, 1173, 209, 478, 177, 579, 1026, 1080, 493, 987, 694, 627, 1154, 141, 626, 1104, 922, 594, 600, 791, 204, 1143, 873, 1011, 220, 395, 620, 582, 1187, 221, 770, 1115, 973, 181, 1165, 803, 728, 842, 1120, 320, 11, 650, 1184, 682, 412, 192, 636, 545, 230, 743, 
1089, 956, 763, 730, 338, 560, 368, 1169, 841, 890, 960, 748, 727, 259, 414, 280, 494, 54, 522, 402, 184, 754, 392, 609, 30, 878, 451, 477, 148, 989, 587, 782, 801, 367, 413, 240, 720, 1168, 161, 1197, 263, 526, 732, 254, 528, 1172, 854, 433, 924, 553, 159, 114, 293, 1119, 176, 449, 337, 843, 686, 725, 869, 399, 871, 131, 38, 663, 829, 697, 1079, 884, 1186, 426, 988, 1031, 558, 180, 364, 72, 98, 589, 837, 877, 599, 86, 715, 695, 712, 437, 561, 265, 610, 500, 160, 1007, 1150, 303, 696, 856, 503, 429, 823, 1153, 1027, 489, 538, 495, 275, 383, 817, 144, 468, 366, 297, 509, 722, 927, 44, 795, 481, 674, 1128, 1004, 1177, 709, 995, 400, 656, 342, 870, 1068, 82, 125, 936, 260, 124, 908, 404, 353, 771, 143, 938, 733, 401, 382, 1072, 118, 1038, 502, 773, 224, 78, 740, 978, 959, 24, 291, 5, 75, 628, 602, 1126, 213, 1036, 497, 1176, 420, 61, 462, 831, 16, 845, 688, 793, 860, 203, 313, 1006, 103, 422, 994, 895, 1062, 315, 381 ] } ================================================ FILE: datainfo/elastic_pendulum/data_split_dict_3.json ================================================ { "test": [ 1194, 644, 1184, 23, 694, 620, 1067, 1157, 895, 1155, 486, 883, 555, 1153, 210, 1152, 1065, 942, 771, 1121, 283, 737, 642, 695, 86, 319, 539, 597, 835, 404, 64, 1096, 221, 157, 14, 634, 1161, 483, 1035, 571, 677, 773, 92, 90, 243, 850, 1167, 601, 41, 1181, 840, 136, 704, 181, 990, 987, 129, 254, 583, 546, 432, 213, 1109, 668, 334, 572, 58, 689, 475, 834, 1093, 718, 790, 1038, 862, 1172, 893, 528, 444, 1013, 278, 73, 199, 748, 274, 910, 808, 874, 793, 968, 551, 63, 616, 87, 326, 131, 31, 798, 1071, 310, 474, 308, 813, 975, 1125, 1107, 963, 392, 479, 1128, 531, 960, 26, 134, 1189, 970, 757, 267, 1114, 487 ], "val": [ 710, 822, 925, 35, 46, 250, 829, 122, 293, 602, 878, 1029, 882, 232, 911, 349, 556, 295, 1190, 765, 678, 1079, 856, 467, 763, 33, 227, 734, 949, 972, 1145, 1158, 522, 637, 901, 852, 965, 1047, 4, 204, 159, 423, 1097, 36, 419, 581, 429, 297, 426, 649, 354, 1148, 277, 870, 812, 
530, 298, 431, 132, 603, 353, 1051, 825, 175, 696, 570, 375, 645, 390, 69, 247, 460, 554, 923, 446, 1147, 163, 346, 897, 459, 683, 659, 208, 198, 891, 383, 671, 488, 1170, 455, 1024, 1115, 269, 55, 214, 772, 615, 540, 1086, 640, 379, 745, 363, 655, 611, 934, 514, 756, 43, 124, 45, 1002, 1119, 1074, 722, 954, 680, 123, 272, 1098 ], "train": [ 464, 1062, 890, 22, 660, 427, 290, 167, 469, 1159, 341, 712, 560, 457, 77, 952, 639, 462, 1185, 1195, 654, 135, 83, 797, 730, 372, 650, 218, 225, 787, 285, 265, 343, 279, 452, 1028, 397, 951, 948, 421, 855, 625, 158, 846, 1004, 1009, 753, 72, 466, 559, 914, 80, 591, 738, 142, 1120, 847, 408, 708, 973, 552, 236, 364, 1186, 207, 606, 1103, 338, 741, 622, 282, 178, 991, 263, 1054, 1112, 612, 586, 1072, 1193, 371, 1129, 480, 66, 1176, 781, 691, 507, 735, 853, 91, 575, 386, 192, 928, 67, 335, 785, 656, 585, 220, 789, 374, 347, 706, 231, 196, 188, 752, 681, 348, 324, 1177, 309, 151, 727, 929, 525, 688, 100, 95, 1141, 1124, 352, 519, 1090, 624, 1075, 623, 266, 351, 219, 723, 1133, 470, 1085, 75, 38, 1150, 510, 1031, 998, 800, 8, 641, 676, 370, 104, 288, 766, 415, 286, 292, 521, 411, 42, 1182, 443, 1110, 1092, 811, 873, 1073, 731, 842, 977, 350, 561, 788, 424, 1179, 312, 876, 906, 141, 321, 228, 564, 294, 791, 414, 922, 76, 118, 328, 138, 1166, 1156, 356, 628, 714, 841, 777, 1142, 799, 803, 103, 875, 844, 638, 1063, 1055, 726, 953, 1113, 313, 518, 315, 449, 935, 1160, 65, 881, 482, 6, 633, 784, 296, 498, 10, 1081, 337, 331, 865, 867, 434, 1169, 380, 605, 394, 154, 889, 884, 393, 412, 143, 217, 500, 587, 783, 815, 1052, 945, 535, 262, 961, 1007, 635, 12, 284, 61, 1089, 767, 672, 1076, 1197, 607, 717, 1117, 327, 999, 866, 1149, 0, 399, 1010, 345, 230, 994, 667, 121, 1036, 49, 1144, 1199, 1101, 992, 489, 451, 249, 533, 979, 927, 137, 666, 595, 496, 770, 959, 212, 534, 1099, 978, 1100, 197, 89, 697, 686, 171, 913, 1017, 477, 1146, 211, 558, 827, 1108, 201, 156, 619, 1039, 179, 915, 1088, 679, 437, 169, 657, 96, 27, 246, 502, 627, 684, 473, 
647, 1005, 53, 739, 1046, 303, 485, 851, 29, 955, 373, 576, 699, 172, 15, 1131, 287, 930, 725, 589, 238, 653, 148, 454, 304, 1171, 511, 59, 318, 170, 84, 97, 857, 1162, 48, 947, 112, 481, 252, 703, 941, 1026, 617, 981, 837, 1021, 16, 1042, 162, 664, 838, 760, 517, 1059, 251, 1033, 420, 1154, 886, 71, 1198, 325, 1104, 257, 1015, 299, 830, 1084, 614, 848, 506, 721, 966, 849, 715, 596, 1018, 107, 888, 759, 418, 305, 669, 504, 1014, 608, 1027, 536, 368, 19, 233, 1122, 224, 796, 1020, 447, 234, 698, 11, 545, 956, 750, 106, 216, 907, 34, 1135, 242, 899, 200, 82, 818, 450, 1140, 476, 209, 21, 402, 1053, 161, 248, 113, 387, 357, 819, 543, 833, 32, 1095, 1048, 916, 1087, 40, 366, 17, 396, 578, 1006, 439, 1080, 68, 983, 355, 682, 445, 168, 165, 648, 9, 898, 743, 39, 259, 119, 316, 1061, 950, 25, 472, 574, 782, 405, 513, 971, 1056, 456, 858, 674, 145, 289, 57, 986, 786, 1044, 1180, 709, 416, 821, 108, 1019, 940, 693, 828, 744, 646, 805, 177, 503, 239, 905, 2, 859, 526, 824, 764, 920, 317, 871, 562, 417, 1049, 631, 300, 378, 610, 206, 588, 174, 1043, 407, 149, 912, 621, 1008, 193, 189, 1139, 673, 565, 817, 155, 114, 816, 361, 806, 458, 413, 340, 362, 700, 253, 794, 516, 685, 845, 705, 832, 1187, 900, 1078, 754, 367, 401, 976, 755, 185, 969, 1041, 180, 1105, 768, 468, 652, 944, 854, 974, 629, 144, 573, 110, 1130, 742, 1060, 775, 153, 186, 724, 301, 60, 924, 520, 505, 936, 538, 939, 579, 1037, 1058, 860, 1050, 461, 377, 807, 1064, 99, 256, 1057, 448, 115, 442, 78, 281, 917, 885, 557, 932, 400, 872, 701, 215, 747, 497, 388, 1034, 339, 342, 30, 919, 344, 670, 463, 365, 1003, 1016, 270, 780, 120, 1173, 93, 894, 187, 320, 663, 173, 261, 762, 636, 191, 140, 264, 194, 746, 260, 769, 880, 728, 235, 478, 964, 903, 74, 988, 809, 1025, 1143, 594, 1188, 1094, 736, 491, 758, 147, 879, 329, 1196, 861, 1138, 176, 205, 391, 490, 1164, 1083, 314, 222, 1132, 692, 598, 609, 938, 687, 302, 550, 237, 826, 887, 599, 150, 957, 273, 139, 190, 863, 453, 720, 1151, 509, 98, 166, 152, 410, 658, 1000, 
495, 702, 529, 690, 593, 582, 802, 425, 164, 422, 240, 512, 776, 85, 823, 569, 358, 406, 613, 376, 931, 403, 716, 630, 465, 255, 661, 1163, 1030, 79, 184, 761, 332, 563, 962, 566, 102, 567, 109, 943, 732, 541, 523, 779, 37, 1123, 306, 592, 291, 713, 493, 1118, 117, 1192, 804, 933, 24, 183, 384, 675, 101, 792, 94, 896, 1070, 864, 241, 146, 892, 5, 995, 1127, 105, 544, 577, 1174, 751, 111, 1040, 385, 542, 1178, 749, 982, 707, 1011, 1069, 268, 1134, 1045, 333, 28, 133, 436, 229, 435, 868, 604, 711, 276, 548, 223, 996, 18, 909, 428, 1191, 430, 869, 719, 778, 508, 1102, 381, 937, 441, 993, 44, 820, 651, 1082, 1032, 926, 665, 618, 471, 382, 1068, 182, 580, 81, 814, 389, 740, 733, 3, 568, 130, 126, 438, 336, 195, 271, 369, 311, 600, 398, 662, 395, 359, 1116, 322, 160, 323, 501, 1066, 527, 88, 729, 1126, 985, 494, 70, 902, 127, 47, 1136, 1168, 56, 877, 1, 839, 795, 330, 484, 52, 433, 532, 1106, 632, 275, 1137, 1012, 774, 51, 409, 1091, 584, 643, 499, 7, 843, 1175, 1022, 1165, 280, 810, 245, 958, 967, 836, 908, 547, 125, 202, 226, 360, 62, 801, 831, 997, 553, 989, 258, 128, 1183, 116, 626, 1111, 54, 1001, 549, 537, 20, 980, 244, 307, 515, 1023, 1077, 984, 492, 13, 50, 590, 440, 921, 904, 918, 203, 946, 524 ] } ================================================ FILE: datainfo/fire/data_split_dict_1.json ================================================ { "test": [ 248, 491, 35, 873, 603, 402, 517, 866, 511, 903, 601, 290, 310, 194, 686, 849, 519, 951, 512, 728, 968, 917, 340, 760, 123, 303, 880, 741, 644, 190, 102, 657, 569, 857, 426, 960, 296, 470, 989, 224, 693, 236, 353, 238, 566, 990, 448, 994, 227, 540, 980, 743, 432, 221, 702, 390, 902, 9, 554, 665, 26, 22, 31, 325, 923, 104, 605, 234, 995, 738, 272, 456, 712, 2, 785, 780, 622, 443, 399, 855, 914, 29, 499, 96, 214, 807, 388, 667, 483, 460, 779, 507, 120, 261, 64, 782, 821, 867, 582, 137 ], "val": [ 992, 564, 93, 185, 598, 563, 181, 650, 235, 28, 614, 469, 339, 627, 805, 638, 553, 551, 1, 354, 895, 365, 496, 423, 516, 856, 
567, 583, 373, 492, 57, 436, 210, 574, 796, 531, 132, 828, 524, 758, 802, 392, 5, 746, 623, 854, 675, 275, 936, 361, 591, 352, 526, 414, 237, 891, 552, 204, 789, 12, 232, 514, 172, 174, 662, 403, 592, 607, 629, 720, 315, 44, 480, 30, 750, 501, 379, 904, 860, 533, 167, 797, 110, 520, 679, 449, 88, 383, 755, 690, 794, 719, 561, 375, 177, 680, 424, 413, 816, 761 ], "train": [ 852, 836, 280, 575, 208, 609, 931, 725, 47, 984, 337, 581, 155, 692, 862, 790, 572, 241, 991, 788, 913, 498, 107, 767, 839, 930, 300, 386, 490, 920, 49, 216, 886, 73, 125, 565, 262, 717, 596, 669, 309, 803, 876, 848, 580, 864, 447, 484, 418, 147, 556, 323, 37, 586, 438, 409, 52, 599, 687, 878, 716, 243, 131, 562, 219, 95, 808, 82, 696, 306, 901, 144, 998, 415, 548, 4, 704, 707, 143, 269, 879, 479, 284, 888, 950, 294, 949, 13, 705, 515, 751, 356, 121, 621, 872, 894, 982, 472, 477, 863, 954, 295, 729, 378, 641, 753, 58, 798, 371, 271, 973, 169, 239, 207, 358, 955, 600, 509, 434, 585, 34, 634, 433, 311, 701, 475, 331, 473, 615, 7, 421, 266, 56, 384, 183, 391, 870, 200, 100, 59, 494, 893, 146, 685, 396, 892, 316, 176, 180, 834, 766, 19, 806, 68, 795, 453, 231, 616, 53, 246, 912, 709, 814, 23, 336, 570, 811, 481, 113, 168, 762, 362, 697, 730, 543, 25, 289, 197, 11, 987, 628, 233, 734, 655, 242, 276, 77, 915, 698, 444, 594, 67, 89, 393, 154, 673, 417, 299, 726, 199, 510, 39, 206, 530, 942, 446, 482, 525, 768, 635, 267, 370, 979, 678, 837, 966, 559, 842, 568, 293, 877, 109, 148, 99, 381, 345, 134, 544, 244, 135, 590, 158, 459, 368, 838, 883, 366, 405, 945, 425, 926, 910, 865, 940, 50, 343, 186, 815, 784, 529, 184, 642, 478, 338, 254, 817, 285, 677, 777, 829, 550, 874, 824, 981, 179, 749, 226, 927, 0, 633, 357, 257, 223, 493, 537, 410, 695, 69, 971, 905, 6, 660, 666, 55, 700, 380, 350, 631, 62, 215, 534, 182, 487, 653, 742, 360, 875, 139, 708, 451, 75, 145, 419, 699, 706, 124, 15, 882, 474, 630, 412, 97, 972, 394, 497, 127, 142, 151, 868, 209, 613, 307, 454, 61, 108, 527, 318, 495, 825, 166, 377, 422, 
906, 195, 963, 947, 654, 869, 781, 539, 652, 51, 321, 130, 953, 268, 733, 278, 612, 476, 196, 178, 326, 201, 640, 401, 959, 291, 757, 648, 452, 79, 455, 193, 884, 658, 535, 63, 43, 304, 853, 840, 359, 84, 928, 282, 810, 916, 152, 159, 964, 385, 81, 188, 676, 643, 745, 770, 786, 437, 140, 334, 372, 312, 286, 776, 211, 737, 727, 595, 407, 253, 775, 122, 115, 374, 571, 898, 450, 332, 844, 588, 270, 90, 3, 440, 119, 45, 715, 944, 671, 674, 649, 141, 819, 793, 429, 900, 683, 70, 935, 813, 818, 791, 956, 292, 213, 330, 887, 602, 899, 835, 87, 651, 428, 202, 841, 542, 467, 435, 938, 277, 624, 281, 661, 896, 731, 843, 32, 398, 129, 126, 341, 441, 820, 764, 80, 787, 952, 153, 506, 489, 941, 382, 430, 606, 464, 344, 851, 763, 462, 502, 161, 541, 16, 549, 997, 774, 503, 925, 457, 943, 625, 308, 847, 427, 263, 363, 54, 156, 937, 688, 420, 523, 754, 397, 546, 885, 722, 486, 260, 619, 538, 513, 532, 157, 558, 505, 367, 946, 573, 934, 76, 714, 647, 812, 723, 249, 20, 744, 250, 94, 103, 342, 251, 911, 632, 191, 670, 42, 978, 682, 859, 986, 92, 458, 91, 932, 735, 369, 252, 203, 314, 610, 957, 187, 265, 364, 871, 636, 220, 256, 114, 466, 324, 65, 993, 408, 320, 400, 988, 327, 908, 10, 27, 597, 962, 212, 929, 322, 488, 756, 617, 771, 555, 752, 445, 105, 710, 247, 974, 732, 118, 165, 922, 245, 759, 996, 608, 822, 801, 792, 858, 611, 46, 881, 283, 468, 975, 659, 577, 302, 827, 273, 83, 579, 229, 804, 584, 713, 739, 909, 117, 349, 718, 150, 389, 74, 416, 40, 724, 684, 800, 593, 668, 500, 618, 656, 298, 765, 348, 346, 376, 778, 740, 809, 921, 274, 958, 897, 170, 173, 136, 86, 41, 66, 240, 545, 967, 547, 783, 560, 985, 98, 969, 218, 439, 347, 138, 576, 335, 939, 160, 748, 288, 411, 626, 333, 889, 907, 823, 924, 977, 681, 106, 504, 198, 965, 976, 587, 831, 101, 355, 205, 387, 703, 521, 637, 175, 471, 826, 222, 604, 38, 832, 8, 133, 258, 578, 933, 162, 769, 317, 78, 833, 313, 48, 217, 128, 305, 60, 919, 646, 845, 328, 589, 691, 404, 961, 664, 536, 228, 461, 528, 711, 645, 225, 557, 830, 
694, 518, 721, 970, 164, 736, 36, 149, 406, 18, 230, 21, 442, 620, 983, 522, 747, 259, 111, 264, 192, 431, 351, 395, 319, 24, 116, 485, 508, 329, 890, 465, 301, 918, 663, 279, 672, 861, 948, 799, 163, 171, 71, 297, 850, 189, 639, 112, 846, 255, 287, 773, 772, 14, 463, 17, 85, 72, 689, 33 ] } ================================================ FILE: datainfo/fire/data_split_dict_2.json ================================================ { "test": [ 465, 677, 921, 956, 851, 527, 512, 510, 285, 501, 255, 543, 670, 472, 756, 732, 409, 772, 165, 930, 879, 370, 362, 607, 808, 966, 781, 537, 752, 424, 815, 456, 915, 186, 947, 690, 526, 368, 938, 522, 139, 177, 332, 180, 24, 236, 241, 181, 573, 168, 538, 905, 913, 433, 389, 929, 326, 954, 476, 372, 28, 891, 979, 922, 274, 514, 455, 958, 557, 380, 521, 880, 740, 822, 402, 653, 441, 162, 697, 595, 36, 621, 217, 620, 257, 315, 874, 685, 828, 753, 173, 855, 369, 86, 93, 57, 869, 970, 883, 978 ], "val": [ 4, 37, 155, 253, 44, 603, 394, 1, 708, 535, 188, 927, 160, 130, 261, 382, 21, 746, 41, 25, 69, 117, 84, 943, 688, 909, 176, 936, 371, 58, 32, 777, 734, 952, 61, 215, 250, 272, 939, 534, 916, 849, 698, 232, 605, 279, 50, 668, 588, 60, 108, 762, 195, 887, 8, 895, 349, 840, 803, 77, 638, 700, 375, 524, 500, 212, 748, 319, 416, 602, 630, 667, 519, 530, 575, 516, 903, 723, 818, 310, 316, 491, 791, 963, 631, 170, 990, 716, 834, 941, 227, 674, 498, 467, 741, 570, 743, 581, 359, 912 ], "train": [ 290, 214, 112, 544, 695, 541, 792, 671, 658, 339, 298, 644, 820, 496, 92, 289, 206, 211, 997, 218, 856, 847, 988, 62, 844, 22, 559, 507, 266, 643, 793, 870, 866, 542, 432, 300, 471, 647, 99, 627, 64, 554, 379, 245, 152, 629, 276, 127, 470, 251, 810, 482, 308, 579, 560, 307, 928, 505, 683, 153, 711, 750, 357, 539, 66, 710, 701, 341, 513, 457, 146, 492, 111, 485, 348, 529, 574, 959, 469, 322, 759, 356, 552, 830, 43, 299, 120, 91, 398, 649, 948, 648, 968, 365, 645, 463, 443, 440, 682, 373, 865, 737, 733, 873, 662, 924, 216, 897, 651, 158, 454, 805, 
839, 42, 660, 282, 2, 417, 745, 993, 894, 387, 229, 736, 821, 189, 198, 556, 106, 105, 899, 70, 517, 950, 466, 515, 867, 508, 576, 151, 807, 614, 9, 582, 680, 779, 228, 864, 249, 580, 831, 405, 10, 563, 964, 49, 488, 878, 687, 448, 871, 301, 140, 626, 434, 747, 378, 311, 284, 312, 377, 361, 911, 499, 624, 525, 80, 101, 654, 283, 715, 390, 87, 549, 295, 306, 325, 547, 461, 600, 26, 487, 817, 980, 785, 744, 795, 100, 854, 439, 532, 616, 969, 427, 798, 13, 901, 264, 410, 286, 63, 29, 945, 33, 892, 797, 355, 726, 618, 135, 742, 350, 133, 208, 273, 806, 294, 51, 243, 15, 138, 309, 178, 890, 7, 991, 35, 123, 661, 659, 328, 545, 219, 12, 376, 166, 985, 882, 269, 967, 406, 354, 6, 344, 56, 584, 555, 129, 331, 612, 156, 436, 438, 758, 729, 848, 340, 925, 55, 587, 231, 951, 197, 210, 601, 842, 205, 802, 450, 813, 782, 278, 976, 835, 242, 611, 479, 116, 351, 569, 59, 415, 974, 511, 692, 889, 794, 735, 90, 73, 565, 190, 172, 23, 122, 126, 923, 920, 767, 646, 423, 934, 919, 128, 169, 787, 597, 48, 914, 665, 693, 594, 858, 900, 145, 975, 199, 714, 385, 363, 277, 717, 223, 333, 931, 183, 933, 179, 53, 850, 65, 425, 568, 780, 622, 244, 225, 613, 446, 709, 452, 202, 421, 809, 800, 343, 955, 318, 836, 396, 770, 360, 827, 712, 804, 755, 684, 304, 623, 778, 175, 761, 193, 431, 281, 137, 419, 490, 71, 608, 94, 191, 347, 97, 566, 673, 845, 81, 167, 837, 110, 327, 397, 786, 453, 944, 149, 460, 572, 85, 679, 713, 504, 447, 182, 163, 109, 704, 194, 775, 637, 699, 994, 789, 142, 267, 641, 287, 234, 119, 705, 918, 45, 321, 258, 408, 76, 271, 305, 200, 672, 324, 632, 47, 430, 604, 201, 329, 334, 578, 407, 330, 884, 801, 606, 550, 132, 39, 751, 596, 696, 635, 247, 591, 888, 681, 776, 346, 445, 707, 226, 288, 957, 823, 615, 838, 187, 0, 239, 719, 617, 739, 79, 593, 486, 640, 875, 824, 546, 27, 819, 694, 442, 592, 403, 540, 506, 766, 207, 724, 480, 384, 388, 590, 518, 906, 314, 40, 14, 610, 983, 583, 136, 235, 171, 567, 881, 473, 248, 302, 386, 67, 656, 3, 358, 74, 209, 478, 940, 551, 493, 689, 
853, 141, 908, 796, 577, 204, 586, 706, 220, 395, 221, 910, 17, 678, 946, 562, 896, 977, 320, 11, 391, 412, 192, 774, 83, 338, 876, 483, 666, 816, 414, 280, 494, 54, 937, 763, 184, 392, 30, 451, 477, 148, 367, 413, 240, 898, 161, 989, 263, 935, 254, 528, 336, 764, 720, 553, 159, 114, 293, 536, 825, 449, 337, 962, 399, 609, 131, 38, 88, 68, 426, 150, 558, 942, 364, 72, 893, 589, 599, 992, 437, 561, 783, 790, 811, 771, 634, 303, 503, 754, 489, 718, 495, 275, 383, 144, 468, 366, 297, 509, 857, 481, 721, 664, 691, 400, 342, 728, 125, 633, 260, 124, 404, 353, 749, 655, 401, 814, 585, 118, 196, 502, 224, 78, 663, 669, 768, 291, 5, 730, 628, 868, 213, 474, 497, 652, 420, 833, 462, 16, 203, 313, 702, 103, 422, 675, 788, 381, 296, 861, 625, 121, 531, 19, 374, 996, 96, 185, 841, 926, 428, 52, 727, 998, 902, 523, 435, 444, 852, 981, 953, 657, 548, 317, 971, 475, 986, 533, 877, 222, 107, 738, 932, 20, 102, 769, 246, 89, 862, 113, 164, 418, 393, 961, 154, 799, 949, 907, 832, 860, 262, 826, 859, 639, 233, 843, 82, 256, 75, 965, 731, 238, 252, 829, 725, 520, 237, 650, 464, 98, 174, 917, 134, 34, 259, 987, 686, 143, 571, 886, 18, 846, 429, 982, 268, 265, 885, 147, 335, 904, 960, 973, 598, 872, 812, 458, 972, 104, 323, 703, 676, 95, 230, 484, 157, 722, 636, 411, 773, 270, 46, 757, 619, 784, 564, 459, 984, 31, 863, 345, 292, 115, 765, 760, 642, 995, 352 ] } ================================================ FILE: datainfo/fire/data_split_dict_3.json ================================================ { "test": [ 712, 959, 987, 286, 876, 29, 698, 344, 968, 598, 417, 599, 546, 359, 587, 395, 853, 519, 431, 952, 875, 641, 797, 446, 688, 264, 222, 506, 139, 36, 99, 374, 943, 137, 455, 590, 820, 745, 404, 437, 807, 731, 396, 899, 942, 736, 609, 484, 275, 886, 843, 31, 798, 308, 43, 605, 776, 163, 65, 795, 687, 15, 759, 399, 535, 948, 888, 155, 650, 237, 154, 881, 654, 406, 487, 562, 856, 553, 481, 734, 196, 239, 564, 265, 480, 857, 930, 13, 620, 67, 594, 640, 485, 618, 937, 378, 133, 557, 
606, 243 ], "val": [ 603, 630, 830, 870, 322, 593, 866, 11, 835, 561, 310, 533, 924, 447, 918, 998, 732, 737, 649, 441, 277, 916, 635, 105, 572, 697, 945, 659, 914, 532, 471, 385, 859, 141, 368, 321, 347, 953, 706, 159, 269, 625, 298, 909, 202, 32, 548, 614, 110, 78, 7, 317, 928, 241, 517, 285, 981, 338, 600, 735, 386, 46, 779, 629, 619, 45, 121, 425, 787, 938, 300, 20, 969, 420, 68, 819, 352, 90, 495, 971, 874, 493, 64, 127, 291, 273, 913, 851, 648, 216, 671, 730, 106, 582, 585, 554, 334, 970, 715, 167 ], "train": [ 316, 973, 262, 342, 904, 898, 566, 511, 871, 143, 176, 910, 644, 290, 927, 988, 624, 185, 370, 720, 156, 346, 388, 313, 616, 148, 891, 371, 993, 567, 832, 249, 272, 415, 529, 592, 472, 513, 749, 261, 664, 844, 848, 538, 236, 926, 847, 25, 777, 89, 147, 460, 873, 401, 219, 209, 115, 288, 684, 578, 826, 457, 174, 267, 576, 152, 432, 966, 362, 686, 520, 142, 746, 668, 145, 470, 98, 834, 709, 449, 908, 884, 192, 312, 889, 718, 71, 950, 690, 38, 974, 210, 704, 218, 421, 855, 380, 658, 536, 42, 303, 541, 762, 456, 207, 450, 402, 766, 579, 259, 551, 413, 412, 767, 982, 309, 198, 583, 166, 550, 915, 2, 922, 773, 393, 846, 179, 639, 217, 785, 814, 377, 713, 983, 500, 256, 390, 40, 479, 365, 355, 699, 733, 66, 476, 238, 463, 758, 491, 464, 114, 337, 328, 540, 925, 655, 893, 112, 738, 228, 466, 552, 279, 771, 53, 108, 199, 17, 865, 225, 100, 41, 701, 201, 742, 10, 907, 80, 911, 788, 178, 407, 315, 16, 627, 266, 190, 301, 8, 242, 661, 503, 235, 177, 443, 84, 504, 394, 790, 944, 253, 19, 622, 233, 521, 507, 6, 12, 809, 783, 646, 119, 725, 434, 168, 319, 144, 565, 113, 165, 366, 292, 9, 663, 251, 601, 350, 757, 414, 392, 756, 636, 645, 59, 462, 193, 197, 693, 0, 281, 186, 48, 314, 638, 405, 287, 220, 424, 555, 282, 794, 739, 364, 901, 162, 158, 900, 379, 208, 689, 985, 840, 324, 505, 103, 325, 716, 743, 672, 74, 957, 302, 213, 96, 545, 571, 990, 478, 442, 77, 837, 991, 526, 813, 633, 595, 206, 170, 770, 468, 559, 710, 518, 628, 861, 474, 760, 979, 138, 452, 863, 
558, 363, 294, 490, 283, 92, 356, 408, 810, 824, 422, 234, 418, 153, 284, 215, 774, 769, 72, 902, 55, 63, 681, 768, 91, 967, 611, 391, 975, 150, 30, 180, 260, 252, 58, 839, 274, 289, 858, 890, 754, 805, 339, 729, 727, 188, 939, 748, 919, 486, 653, 764, 224, 57, 221, 793, 792, 140, 796, 956, 278, 896, 200, 574, 453, 667, 761, 39, 248, 194, 573, 169, 171, 960, 989, 172, 935, 842, 613, 231, 670, 933, 149, 135, 621, 60, 589, 373, 675, 965, 93, 610, 416, 86, 516, 575, 821, 958, 318, 509, 992, 931, 410, 97, 666, 596, 451, 674, 510, 397, 367, 331, 120, 660, 702, 23, 375, 400, 696, 862, 679, 246, 14, 719, 895, 872, 427, 703, 980, 651, 534, 607, 812, 977, 214, 570, 941, 483, 164, 955, 780, 612, 741, 528, 604, 411, 920, 932, 806, 157, 683, 497, 489, 482, 458, 304, 591, 868, 789, 343, 151, 962, 631, 502, 569, 118, 161, 263, 299, 75, 586, 996, 921, 680, 95, 439, 833, 869, 951, 496, 357, 469, 617, 563, 632, 49, 83, 76, 454, 459, 205, 740, 673, 685, 351, 811, 473, 345, 296, 387, 963, 744, 212, 82, 877, 721, 711, 211, 560, 498, 240, 543, 85, 577, 358, 972, 376, 652, 403, 465, 255, 525, 247, 79, 184, 332, 705, 102, 109, 662, 523, 37, 903, 306, 883, 880, 117, 26, 808, 878, 24, 183, 384, 101, 94, 800, 864, 954, 852, 146, 656, 5, 822, 544, 326, 111, 801, 724, 597, 707, 124, 268, 383, 333, 28, 995, 436, 816, 435, 608, 276, 845, 940, 676, 18, 885, 428, 430, 765, 508, 381, 717, 818, 44, 804, 894, 829, 382, 836, 182, 580, 978, 389, 3, 568, 130, 126, 438, 336, 867, 271, 369, 311, 984, 398, 827, 912, 73, 722, 160, 323, 784, 514, 527, 88, 923, 494, 70, 897, 882, 47, 129, 56, 1, 330, 946, 52, 433, 828, 539, 751, 254, 708, 51, 409, 584, 499, 849, 131, 280, 245, 815, 677, 547, 125, 949, 226, 360, 781, 747, 976, 258, 128, 682, 116, 626, 54, 905, 549, 537, 802, 750, 763, 515, 175, 726, 492, 986, 50, 934, 440, 203, 524, 35, 860, 250, 122, 293, 602, 69, 232, 349, 556, 295, 531, 678, 775, 33, 753, 823, 444, 522, 637, 488, 4, 204, 838, 423, 964, 419, 581, 429, 297, 426, 817, 354, 475, 728, 530, 841, 
917, 132, 782, 353, 634, 87, 657, 348, 786, 187, 588, 803, 195, 542, 34, 123, 230, 879, 461, 961, 223, 936, 906, 81, 173, 448, 229, 691, 341, 329, 772, 104, 929, 714, 665, 445, 694, 191, 335, 244, 947, 669, 227, 512, 850, 134, 752, 700, 892, 27, 107, 831, 307, 270, 825, 778, 320, 189, 372, 181, 327, 615, 997, 305, 467, 643, 257, 994, 21, 692, 62, 799, 22, 501, 755, 854, 723, 623, 791, 695, 361, 477, 340, 642, 887, 61, 136, 647 ] } ================================================ FILE: datainfo/lava_lamp/data_split_dict_1.json ================================================ { "test": [ 408, 380, 248, 491, 35, 402, 511, 290, 310, 194, 519, 539, 512, 340, 123, 303, 190, 102, 426, 544, 296, 470, 224, 236, 353, 238, 561, 448, 227, 554, 432, 221, 390, 9, 26, 22, 31, 325, 104, 234, 272, 456, 2, 443, 399, 29, 499, 96, 214, 388, 483, 460, 507, 120, 261, 64, 137 ], "val": [ 463, 337, 565, 235, 180, 295, 433, 176, 263, 207, 118, 503, 440, 276, 526, 394, 6, 116, 257, 86, 87, 331, 487, 529, 524, 314, 434, 360, 157, 528, 240, 15, 375, 250, 189, 201, 430, 266, 83, 398, 55, 260, 339, 531, 44, 191, 377, 345, 397, 359, 451, 280, 187, 88, 522, 212, 206 ], "train": [ 126, 245, 516, 84, 527, 351, 124, 165, 462, 504, 393, 447, 352, 54, 385, 251, 155, 52, 327, 160, 244, 404, 498, 453, 392, 67, 396, 131, 76, 21, 475, 163, 365, 274, 154, 422, 110, 413, 128, 558, 217, 184, 241, 362, 17, 195, 464, 279, 326, 147, 89, 249, 496, 216, 59, 410, 342, 103, 106, 320, 73, 505, 378, 381, 457, 334, 142, 318, 72, 530, 341, 292, 482, 545, 304, 253, 537, 161, 38, 98, 239, 523, 79, 95, 316, 284, 145, 192, 285, 172, 135, 247, 409, 13, 324, 555, 478, 328, 209, 374, 133, 65, 119, 357, 113, 298, 162, 368, 233, 489, 60, 461, 45, 101, 306, 370, 1, 223, 277, 121, 226, 125, 497, 220, 371, 288, 51, 552, 5, 541, 134, 495, 185, 267, 62, 77, 550, 47, 490, 178, 536, 407, 335, 494, 210, 506, 481, 71, 348, 406, 63, 476, 78, 181, 367, 405, 317, 91, 256, 486, 333, 361, 458, 435, 469, 153, 308, 168, 418, 286, 183, 100, 
198, 179, 343, 450, 449, 468, 508, 321, 502, 297, 25, 315, 122, 425, 329, 500, 376, 138, 415, 312, 301, 525, 40, 204, 560, 553, 354, 208, 41, 358, 532, 338, 61, 48, 421, 444, 222, 417, 23, 557, 171, 237, 3, 225, 428, 11, 243, 27, 70, 412, 534, 452, 229, 75, 93, 269, 271, 389, 57, 140, 146, 442, 363, 200, 484, 480, 538, 309, 32, 424, 37, 474, 300, 97, 350, 547, 355, 349, 543, 366, 141, 467, 34, 485, 445, 149, 369, 174, 431, 188, 382, 347, 170, 391, 509, 384, 510, 68, 43, 20, 33, 562, 549, 427, 273, 387, 518, 49, 136, 109, 219, 173, 69, 167, 305, 80, 533, 144, 205, 166, 439, 255, 429, 151, 199, 53, 252, 99, 356, 293, 437, 50, 423, 488, 193, 477, 493, 472, 455, 520, 111, 302, 19, 446, 4, 441, 479, 289, 213, 383, 330, 158, 39, 400, 156, 24, 108, 564, 152, 30, 346, 323, 372, 294, 202, 559, 332, 268, 114, 230, 264, 322, 112, 278, 436, 259, 228, 82, 18, 74, 203, 542, 115, 10, 540, 517, 107, 563, 129, 492, 132, 556, 215, 175, 197, 159, 12, 58, 242, 254, 164, 501, 232, 150, 364, 473, 139, 336, 471, 270, 403, 81, 85, 513, 148, 459, 94, 419, 56, 454, 127, 143, 395, 386, 7, 231, 8, 42, 36, 344, 16, 130, 282, 46, 92, 299, 281, 90, 546, 117, 411, 14, 307, 548, 169, 313, 514, 319, 465, 275, 0, 177, 535, 182, 416, 515, 211, 258, 466, 283, 291, 186, 246, 28, 218, 105, 287, 521, 265, 66, 414, 262, 379, 420, 438, 401, 196, 551, 373, 311 ] } ================================================ FILE: datainfo/lava_lamp/data_split_dict_2.json ================================================ { "test": [ 462, 460, 465, 523, 512, 510, 285, 501, 255, 472, 409, 165, 528, 370, 362, 546, 424, 456, 186, 526, 368, 531, 522, 139, 177, 332, 180, 24, 236, 241, 181, 168, 538, 433, 389, 326, 476, 372, 28, 557, 274, 514, 455, 380, 521, 402, 441, 162, 36, 217, 257, 315, 173, 369, 86, 93, 57 ], "val": [ 30, 54, 381, 97, 495, 4, 505, 174, 420, 401, 38, 451, 319, 350, 187, 262, 250, 106, 374, 159, 208, 301, 489, 333, 259, 265, 287, 258, 425, 361, 519, 155, 158, 245, 466, 395, 137, 560, 464, 448, 85, 427, 358, 
417, 166, 481, 113, 337, 249, 233, 530, 515, 471, 371, 290, 179, 537 ], "train": [ 89, 278, 542, 391, 303, 504, 223, 215, 320, 114, 559, 327, 377, 331, 488, 336, 76, 555, 14, 206, 539, 81, 96, 507, 153, 124, 150, 225, 470, 399, 284, 50, 307, 5, 33, 203, 509, 199, 541, 445, 341, 457, 218, 239, 273, 83, 335, 246, 19, 556, 322, 393, 224, 475, 387, 340, 385, 463, 352, 169, 357, 133, 252, 483, 90, 384, 230, 170, 508, 305, 411, 207, 244, 347, 296, 3, 79, 143, 421, 440, 423, 379, 189, 275, 221, 256, 408, 175, 243, 517, 458, 355, 121, 163, 183, 61, 55, 563, 272, 1, 188, 397, 110, 62, 228, 392, 182, 529, 254, 449, 102, 151, 482, 104, 375, 478, 498, 365, 443, 346, 324, 72, 312, 27, 141, 279, 496, 234, 253, 154, 160, 413, 410, 551, 144, 291, 439, 487, 506, 364, 292, 516, 105, 64, 235, 70, 220, 13, 415, 240, 46, 545, 7, 400, 339, 264, 394, 193, 103, 286, 311, 518, 43, 271, 329, 63, 454, 238, 195, 406, 138, 396, 547, 363, 219, 317, 297, 84, 407, 171, 398, 99, 280, 32, 281, 520, 212, 405, 511, 544, 293, 491, 45, 226, 323, 536, 147, 328, 149, 21, 202, 561, 178, 109, 140, 66, 152, 213, 40, 479, 342, 75, 513, 438, 300, 383, 527, 330, 122, 360, 316, 68, 95, 204, 548, 117, 91, 74, 426, 535, 429, 11, 198, 120, 540, 304, 6, 295, 533, 444, 100, 164, 492, 298, 432, 419, 247, 366, 502, 31, 412, 276, 493, 192, 35, 270, 200, 416, 59, 98, 251, 112, 39, 277, 283, 308, 145, 480, 499, 469, 306, 248, 210, 435, 231, 8, 101, 156, 359, 314, 211, 288, 390, 190, 148, 289, 60, 562, 356, 486, 485, 48, 92, 268, 494, 26, 313, 261, 558, 222, 436, 108, 194, 549, 484, 345, 237, 266, 524, 111, 53, 353, 564, 553, 51, 418, 123, 44, 467, 56, 348, 209, 196, 554, 403, 461, 269, 142, 434, 131, 428, 534, 490, 446, 447, 41, 128, 37, 227, 119, 404, 431, 260, 118, 325, 232, 49, 87, 82, 67, 17, 129, 430, 343, 71, 503, 9, 450, 214, 310, 134, 132, 459, 73, 167, 263, 500, 201, 299, 477, 414, 543, 550, 52, 161, 351, 338, 47, 115, 242, 78, 497, 318, 205, 135, 23, 378, 309, 282, 229, 157, 15, 468, 172, 146, 565, 382, 552, 
321, 474, 176, 2, 18, 77, 126, 22, 473, 197, 0, 354, 442, 94, 376, 80, 65, 130, 191, 10, 373, 20, 525, 34, 58, 42, 12, 344, 127, 88, 184, 185, 29, 16, 388, 367, 216, 452, 107, 422, 125, 136, 437, 69, 267, 386, 453, 349, 116, 302, 532, 25, 334, 294 ] } ================================================ FILE: datainfo/lava_lamp/data_split_dict_3.json ================================================ { "test": [ 425, 324, 216, 106, 334, 167, 286, 29, 344, 549, 417, 359, 395, 519, 431, 541, 446, 264, 222, 506, 139, 36, 99, 374, 137, 455, 404, 437, 396, 484, 275, 31, 308, 43, 163, 65, 15, 399, 535, 155, 237, 154, 406, 487, 553, 481, 196, 239, 265, 480, 13, 67, 485, 378, 133, 557, 243 ], "val": [ 173, 444, 21, 353, 79, 134, 312, 149, 208, 101, 16, 274, 307, 55, 39, 3, 158, 18, 120, 258, 142, 472, 451, 282, 169, 300, 367, 193, 23, 389, 314, 309, 22, 60, 525, 212, 393, 218, 150, 10, 77, 459, 210, 34, 409, 176, 45, 247, 327, 536, 246, 32, 63, 145, 136, 293, 505 ], "train": [ 188, 28, 326, 171, 483, 422, 494, 319, 440, 305, 59, 355, 379, 299, 73, 383, 178, 54, 376, 331, 352, 454, 234, 421, 112, 42, 442, 529, 354, 445, 342, 364, 46, 187, 565, 126, 103, 436, 181, 104, 199, 35, 41, 450, 287, 191, 38, 449, 427, 539, 156, 526, 358, 270, 382, 517, 313, 388, 325, 85, 24, 555, 458, 242, 402, 466, 426, 272, 423, 37, 411, 443, 227, 292, 100, 175, 273, 298, 322, 470, 75, 56, 335, 91, 504, 551, 424, 384, 51, 119, 19, 231, 141, 500, 372, 74, 229, 88, 512, 205, 168, 520, 283, 279, 511, 202, 375, 71, 44, 269, 195, 47, 419, 50, 297, 351, 78, 276, 263, 248, 124, 410, 252, 105, 302, 157, 394, 476, 545, 435, 380, 509, 130, 9, 232, 76, 72, 236, 516, 362, 385, 387, 288, 206, 337, 471, 254, 336, 253, 194, 221, 561, 530, 260, 1, 330, 26, 528, 144, 81, 277, 96, 562, 524, 118, 320, 217, 143, 190, 83, 7, 33, 318, 57, 251, 225, 151, 69, 255, 465, 267, 4, 200, 460, 390, 107, 226, 179, 369, 381, 201, 398, 110, 159, 240, 162, 207, 284, 166, 428, 117, 363, 281, 498, 508, 20, 203, 338, 185, 183, 245, 478, 
241, 556, 109, 84, 48, 371, 92, 492, 490, 531, 165, 98, 89, 414, 228, 349, 80, 295, 391, 131, 538, 182, 489, 8, 123, 507, 12, 339, 397, 503, 499, 14, 219, 0, 316, 198, 82, 429, 532, 108, 563, 457, 370, 430, 127, 412, 289, 204, 249, 467, 544, 140, 475, 285, 356, 62, 534, 306, 420, 521, 310, 129, 64, 477, 58, 27, 482, 463, 268, 488, 365, 546, 257, 87, 418, 502, 6, 25, 433, 461, 262, 340, 558, 125, 341, 146, 495, 116, 333, 278, 147, 496, 554, 462, 523, 439, 401, 261, 244, 2, 102, 456, 211, 469, 209, 290, 214, 148, 213, 177, 518, 343, 564, 493, 215, 66, 537, 497, 501, 542, 328, 174, 400, 93, 294, 522, 97, 271, 17, 61, 115, 434, 230, 373, 111, 360, 172, 40, 86, 224, 114, 345, 407, 164, 386, 438, 49, 357, 332, 527, 547, 95, 514, 122, 533, 513, 113, 256, 468, 560, 350, 291, 559, 53, 447, 153, 135, 448, 392, 474, 94, 186, 90, 543, 464, 303, 152, 233, 408, 128, 189, 416, 346, 540, 413, 11, 250, 377, 473, 361, 311, 405, 347, 180, 238, 170, 321, 432, 30, 68, 323, 301, 315, 486, 491, 161, 296, 552, 403, 5, 452, 280, 548, 453, 132, 223, 550, 121, 366, 368, 510, 220, 138, 259, 415, 317, 52, 515, 348, 304, 329, 197, 266, 235, 192, 479, 441, 70, 184, 160 ] } ================================================ FILE: datainfo/reaction_diffusion/data_split_dict_1.json ================================================ { "test": [ 83, 60, 57, 63, 15, 32, 8, 97, 72, 17 ], "val": [ 92, 0, 77, 55, 49, 3, 62, 12, 26, 48 ], "train": [ 53, 37, 65, 51, 4, 20, 38, 9, 10, 81, 44, 36, 84, 50, 96, 90, 66, 16, 80, 33, 24, 52, 91, 99, 64, 5, 58, 76, 39, 79, 23, 94, 30, 73, 25, 47, 31, 45, 19, 87, 42, 68, 95, 21, 7, 67, 46, 82, 11, 6, 41, 86, 88, 70, 18, 78, 71, 59, 43, 61, 22, 14, 35, 93, 56, 28, 98, 54, 27, 89, 1, 69, 74, 2, 85, 40, 13, 75, 29, 34 ] } ================================================ FILE: datainfo/reaction_diffusion/data_split_dict_2.json ================================================ { "test": [ 77, 32, 39, 85, 94, 21, 46, 10, 11, 7 ], "val": [ 47, 65, 50, 81, 55, 20, 74, 4, 90, 27 
], "train": [ 0, 76, 61, 63, 1, 71, 2, 6, 16, 19, 13, 24, 49, 12, 75, 9, 83, 72, 5, 41, 99, 45, 89, 53, 79, 18, 52, 92, 14, 42, 68, 44, 38, 84, 36, 17, 31, 15, 70, 88, 25, 97, 51, 73, 66, 37, 78, 33, 80, 26, 82, 28, 60, 35, 43, 57, 23, 58, 91, 8, 62, 93, 98, 86, 29, 30, 22, 95, 67, 54, 48, 40, 59, 96, 3, 87, 34, 64, 56, 69 ] } ================================================ FILE: datainfo/reaction_diffusion/data_split_dict_3.json ================================================ { "test": [ 8, 74, 80, 60, 77, 47, 16, 69, 75, 30 ], "val": [ 85, 97, 87, 24, 29, 70, 33, 93, 1, 94 ], "train": [ 35, 41, 45, 4, 76, 12, 0, 64, 42, 11, 81, 39, 7, 48, 78, 15, 53, 62, 9, 52, 68, 55, 63, 65, 61, 67, 95, 18, 58, 86, 99, 92, 10, 17, 72, 21, 14, 44, 37, 91, 22, 59, 83, 32, 26, 98, 40, 27, 43, 96, 13, 31, 57, 2, 6, 23, 56, 71, 28, 36, 51, 46, 25, 54, 73, 79, 34, 3, 38, 5, 20, 90, 88, 49, 66, 89, 84, 19, 50, 82 ] } ================================================ FILE: datainfo/single_pendulum/data_split_dict_1.json ================================================ { "test": [ 187, 370, 362, 470, 57, 938, 678, 3, 708, 1137, 730, 993, 846, 1033, 409, 746, 985, 114, 872, 420, 1062, 264, 1049, 785, 11, 551, 940, 723, 704, 1052, 828, 475, 1105, 408, 25, 464, 1028, 345, 348, 806, 631, 89, 961, 60, 1002, 758, 1142, 1066, 335, 221, 1041, 898, 177, 767, 1123, 751, 354, 848, 827, 497, 983, 70, 805, 1034, 1022, 581, 621, 388, 1039, 1171, 1025, 681, 247, 607, 380, 204, 1139, 852, 44, 593, 941, 448, 472, 707, 477, 1132, 1015, 896, 454, 1080, 59, 864, 443, 780, 18, 1108, 52, 45, 62, 650, 209, 468, 545, 912, 4, 886, 798, 58, 999, 192, 429, 777, 967, 920, 1014, 241, 522, 129, 1165, 275 ], "val": [ 411, 892, 626, 333, 17, 1071, 516, 303, 399, 1151, 960, 106, 504, 198, 605, 1172, 918, 690, 587, 210, 101, 355, 205, 387, 1000, 521, 637, 720, 1186, 890, 888, 847, 175, 471, 583, 922, 1096, 1046, 839, 604, 38, 870, 899, 574, 8, 133, 258, 578, 426, 162, 761, 305, 1122, 939, 317, 78, 879, 1036, 313, 1053, 
1167, 217, 991, 257, 611, 120, 1093, 657, 808, 1178, 457, 923, 451, 873, 1183, 328, 72, 299, 813, 36, 461, 42, 884, 428, 1044, 519, 222, 529, 385, 862, 703, 791, 638, 48, 233, 970, 1016, 659, 931, 603, 558, 344, 1079, 326, 342, 142, 594, 705, 378, 224, 550, 511, 575, 29, 927, 34, 170, 144, 66, 1196 ], "train": [ 483, 489, 341, 685, 96, 1, 1148, 542, 656, 744, 1085, 613, 855, 775, 680, 111, 1081, 648, 308, 733, 1023, 189, 759, 61, 482, 639, 1190, 228, 776, 596, 885, 1024, 1187, 1048, 295, 1158, 1146, 1094, 819, 684, 540, 793, 856, 1099, 622, 262, 773, 609, 107, 263, 260, 1160, 958, 474, 672, 739, 442, 478, 887, 1083, 734, 612, 1084, 956, 843, 55, 1067, 750, 823, 9, 532, 112, 282, 329, 520, 874, 359, 1125, 1102, 514, 668, 1164, 459, 197, 844, 753, 635, 633, 185, 1195, 966, 201, 906, 1059, 569, 146, 580, 39, 1047, 446, 597, 276, 43, 1042, 533, 561, 166, 1003, 154, 778, 647, 243, 67, 768, 617, 1149, 634, 981, 649, 450, 507, 215, 1057, 740, 691, 710, 1120, 858, 889, 682, 71, 531, 869, 415, 849, 444, 237, 716, 214, 764, 735, 765, 595, 486, 90, 834, 989, 437, 292, 1127, 549, 330, 16, 304, 559, 1168, 104, 186, 916, 719, 99, 623, 476, 293, 498, 286, 1136, 905, 503, 143, 952, 917, 167, 242, 26, 797, 149, 315, 652, 897, 871, 234, 271, 223, 1029, 769, 238, 673, 924, 294, 1199, 567, 33, 207, 821, 945, 795, 895, 548, 784, 729, 987, 412, 908, 986, 965, 375, 427, 6, 891, 456, 195, 980, 641, 435, 419, 1004, 925, 721, 538, 80, 718, 762, 679, 1063, 131, 937, 311, 307, 552, 485, 1054, 1043, 959, 1176, 840, 1037, 432, 699, 1177, 812, 424, 1157, 752, 360, 190, 68, 316, 693, 1194, 383, 384, 951, 517, 741, 1110, 1072, 671, 676, 219, 179, 1121, 30, 396, 1106, 1005, 585, 402, 321, 954, 943, 606, 749, 255, 859, 530, 865, 390, 395, 976, 877, 556, 268, 1007, 755, 957, 909, 702, 933, 509, 929, 130, 582, 289, 1073, 277, 756, 973, 1100, 394, 525, 81, 598, 100, 147, 379, 196, 644, 851, 2, 782, 351, 824, 1090, 56, 995, 586, 636, 919, 660, 108, 199, 1045, 280, 1091, 1107, 553, 757, 421, 193, 655, 
653, 119, 494, 492, 1019, 675, 343, 1134, 140, 339, 692, 453, 1018, 438, 1163, 964, 571, 695, 640, 28, 696, 95, 51, 510, 164, 493, 910, 54, 1188, 156, 911, 998, 1170, 123, 838, 158, 184, 760, 401, 371, 1141, 974, 269, 235, 134, 770, 75, 372, 747, 227, 296, 1032, 374, 1173, 487, 97, 434, 127, 893, 1129, 841, 206, 591, 15, 338, 125, 361, 619, 694, 113, 366, 599, 413, 469, 404, 842, 88, 1155, 787, 200, 1092, 1068, 738, 714, 463, 630, 543, 913, 334, 152, 481, 1114, 645, 285, 462, 854, 188, 337, 267, 425, 357, 984, 717, 137, 1031, 121, 947, 270, 1131, 473, 484, 1086, 1013, 836, 570, 47, 353, 82, 1088, 832, 84, 501, 226, 336, 1075, 666, 35, 216, 781, 508, 50, 465, 491, 23, 763, 669, 926, 674, 132, 687, 236, 77, 155, 278, 350, 1058, 715, 1082, 534, 135, 467, 178, 5, 309, 397, 417, 37, 754, 670, 433, 902, 815, 524, 479, 544, 664, 79, 701, 789, 528, 748, 829, 148, 332, 495, 904, 539, 381, 745, 790, 69, 1061, 414, 624, 398, 727, 368, 202, 537, 174, 535, 590, 1162, 1051, 643, 978, 1087, 642, 602, 915, 115, 972, 779, 1038, 731, 358, 816, 452, 850, 0, 331, 1145, 833, 515, 447, 568, 382, 663, 1135, 139, 1192, 13, 589, 689, 867, 290, 1180, 392, 213, 124, 430, 502, 365, 722, 172, 935, 651, 318, 279, 955, 950, 151, 291, 736, 724, 266, 248, 31, 21, 883, 665, 194, 1184, 572, 837, 254, 284, 820, 1150, 418, 546, 441, 122, 1021, 499, 725, 662, 963, 809, 406, 712, 391, 246, 455, 1124, 1065, 85, 875, 928, 914, 783, 536, 1074, 766, 168, 527, 393, 356, 403, 1027, 53, 994, 997, 436, 700, 159, 180, 386, 860, 831, 620, 126, 87, 608, 1111, 169, 1154, 565, 713, 1011, 319, 208, 646, 163, 377, 523, 73, 709, 1069, 281, 625, 573, 1117, 422, 230, 490, 506, 814, 244, 231, 654, 1161, 512, 772, 541, 181, 7, 944, 239, 932, 627, 176, 410, 1098, 407, 868, 802, 737, 907, 1153, 141, 901, 363, 63, 19, 880, 1008, 661, 24, 1156, 1101, 992, 1143, 405, 1104, 1115, 711, 449, 794, 1119, 562, 440, 1030, 698, 811, 1006, 253, 683, 49, 161, 1070, 975, 306, 182, 979, 706, 566, 688, 1109, 323, 32, 1060, 145, 211, 1197, 
300, 616, 526, 726, 109, 312, 817, 1077, 153, 183, 1198, 969, 996, 881, 774, 513, 1133, 157, 799, 505, 367, 297, 822, 1179, 22, 76, 1095, 1017, 900, 1193, 894, 1055, 249, 20, 225, 250, 94, 818, 1103, 962, 557, 103, 1064, 251, 310, 592, 810, 191, 953, 1130, 792, 968, 232, 948, 658, 588, 826, 92, 458, 771, 91, 287, 876, 369, 252, 203, 314, 1138, 554, 1169, 265, 364, 677, 480, 1175, 835, 796, 632, 803, 220, 256, 1097, 466, 615, 324, 65, 64, 1113, 320, 400, 460, 327, 610, 743, 863, 866, 10, 27, 853, 325, 667, 212, 102, 322, 488, 728, 259, 1056, 301, 555, 825, 882, 445, 105, 1010, 1009, 1152, 697, 171, 1040, 118, 165, 431, 600, 804, 245, 1189, 1020, 1116, 845, 1001, 423, 93, 14, 686, 628, 12, 1026, 1144, 46, 1126, 110, 283, 1181, 1012, 934, 577, 302, 373, 273, 83, 579, 229, 563, 584, 1166, 1140, 800, 601, 629, 117, 349, 128, 1089, 150, 807, 1185, 389, 74, 416, 40, 1035, 971, 788, 564, 1159, 949, 500, 732, 988, 618, 990, 930, 298, 116, 1118, 346, 376, 261, 861, 518, 614, 340, 1191, 274, 946, 1112, 1076, 173, 136, 86, 41, 742, 1078, 240, 1182, 786, 496, 547, 1050, 942, 903, 936, 352, 560, 1147, 857, 98, 977, 272, 218, 439, 347, 138, 801, 576, 830, 1128, 878, 982, 160, 1174, 288, 921 ] } ================================================ FILE: datainfo/single_pendulum/data_split_dict_2.json ================================================ { "test": [ 789, 3, 1071, 376, 321, 261, 523, 764, 43, 83, 51, 138, 235, 169, 1167, 510, 352, 737, 742, 116, 65, 866, 123, 431, 501, 544, 1163, 1069, 1113, 464, 559, 100, 120, 217, 391, 17, 699, 154, 750, 1048, 1001, 425, 638, 832, 1039, 1060, 1032, 621, 633, 982, 1181, 340, 664, 454, 996, 935, 718, 1171, 931, 1152, 1055, 1025, 1020, 571, 1003, 511, 1086, 944, 818, 330, 1156, 741, 724, 1178, 1075, 849, 912, 372, 1146, 1052, 736, 1162, 1044, 279, 355, 665, 361, 48, 472, 483, 363, 1147, 336, 1076, 867, 778, 652, 952, 745, 56, 1191, 549, 1028, 911, 1114, 761, 1042, 805, 882, 324, 1190, 73, 434, 515, 631, 346, 739, 173, 187, 115 ], "val": [ 
954, 706, 296, 375, 625, 121, 943, 531, 19, 374, 491, 821, 679, 96, 942, 185, 983, 537, 951, 428, 902, 52, 605, 595, 21, 1158, 435, 444, 825, 746, 215, 986, 1175, 1037, 701, 1108, 389, 657, 548, 317, 1109, 475, 685, 533, 25, 222, 107, 237, 768, 186, 20, 102, 104, 246, 89, 1151, 524, 113, 1029, 418, 393, 1046, 1117, 9, 570, 1101, 525, 1160, 467, 164, 513, 150, 910, 476, 505, 64, 474, 928, 196, 349, 1149, 269, 68, 518, 1099, 287, 36, 859, 536, 530, 698, 294, 671, 997, 804, 1085, 917, 49, 208, 647, 191, 461, 968, 314, 822, 540, 93, 918, 1194, 63, 1000, 690, 585, 231, 704, 8, 74, 310, 507, 88 ], "train": [ 689, 907, 403, 285, 1073, 796, 316, 140, 1183, 132, 861, 1047, 167, 556, 985, 999, 1185, 891, 984, 199, 717, 40, 937, 354, 405, 271, 1155, 1125, 262, 900, 292, 304, 151, 909, 397, 580, 820, 646, 667, 1057, 183, 684, 760, 450, 1159, 726, 198, 747, 880, 1012, 335, 719, 39, 913, 84, 168, 207, 1051, 396, 618, 1088, 834, 567, 146, 923, 814, 921, 629, 299, 1139, 45, 734, 216, 758, 302, 424, 128, 60, 406, 85, 447, 205, 710, 547, 566, 32, 156, 18, 892, 105, 1170, 239, 659, 1061, 797, 76, 248, 541, 641, 783, 969, 1199, 508, 94, 58, 394, 639, 863, 479, 819, 1034, 487, 947, 862, 283, 305, 327, 438, 744, 975, 386, 723, 202, 756, 979, 236, 1092, 1094, 774, 343, 219, 920, 360, 945, 601, 810, 998, 568, 932, 197, 603, 830, 527, 808, 940, 46, 792, 875, 816, 1005, 506, 916, 634, 976, 244, 1009, 2, 1198, 290, 766, 1182, 840, 417, 182, 249, 210, 443, 623, 307, 590, 1043, 592, 925, 155, 858, 675, 90, 427, 442, 848, 846, 1093, 692, 1030, 1131, 35, 255, 111, 1166, 469, 1016, 22, 614, 377, 1018, 26, 195, 499, 165, 37, 333, 883, 899, 247, 906, 981, 288, 365, 616, 693, 1070, 7, 1135, 441, 1179, 874, 188, 992, 851, 347, 1054, 456, 258, 991, 529, 1110, 1066, 970, 69, 370, 457, 735, 1137, 546, 1050, 13, 879, 312, 348, 872, 23, 226, 611, 1130, 678, 445, 57, 557, 569, 853, 668, 532, 272, 806, 379, 977, 974, 223, 266, 385, 127, 971, 31, 562, 1112, 1058, 624, 753, 798, 1161, 300, 281, 980, 552, 253, 
42, 142, 228, 486, 573, 436, 1103, 407, 459, 781, 1023, 1142, 780, 1077, 225, 175, 517, 714, 163, 29, 1157, 622, 14, 6, 961, 473, 274, 660, 157, 329, 295, 1017, 241, 617, 1065, 1021, 158, 731, 828, 811, 1064, 227, 378, 888, 565, 1063, 586, 1116, 411, 1098, 707, 232, 229, 87, 1100, 423, 833, 1041, 838, 1134, 612, 133, 896, 214, 109, 578, 946, 887, 645, 398, 1121, 539, 757, 670, 212, 211, 687, 572, 865, 1196, 700, 80, 584, 149, 588, 662, 351, 1136, 972, 542, 708, 325, 521, 455, 1111, 1096, 957, 432, 306, 786, 655, 1033, 458, 591, 850, 1193, 134, 380, 1174, 1148, 1144, 1040, 1010, 962, 200, 99, 683, 876, 384, 282, 308, 482, 1192, 1123, 857, 357, 894, 583, 669, 117, 615, 250, 785, 466, 564, 139, 575, 1090, 607, 50, 415, 1053, 1097, 868, 950, 66, 498, 1180, 1118, 174, 152, 953, 193, 967, 898, 1084, 77, 1019, 847, 53, 705, 776, 345, 649, 1056, 268, 581, 201, 490, 331, 318, 635, 619, 672, 519, 598, 27, 632, 108, 492, 534, 126, 59, 276, 812, 0, 251, 855, 484, 122, 448, 270, 1124, 596, 267, 416, 941, 835, 949, 964, 749, 930, 800, 277, 864, 356, 70, 15, 286, 71, 1013, 33, 1102, 1127, 460, 439, 897, 613, 332, 419, 172, 784, 504, 703, 933, 12, 844, 112, 606, 772, 948, 129, 4, 593, 408, 234, 993, 666, 637, 179, 512, 965, 839, 966, 110, 676, 463, 751, 421, 815, 465, 410, 1078, 794, 92, 1189, 1015, 485, 1067, 955, 233, 597, 1195, 119, 1059, 576, 135, 278, 1133, 440, 1035, 759, 147, 1107, 929, 344, 47, 245, 252, 914, 608, 654, 256, 369, 430, 885, 257, 339, 769, 145, 97, 680, 206, 323, 807, 1022, 752, 1138, 554, 787, 242, 604, 555, 738, 446, 721, 809, 171, 359, 106, 711, 130, 642, 322, 190, 1014, 488, 802, 362, 673, 178, 677, 904, 653, 1083, 788, 1105, 630, 799, 328, 775, 713, 1045, 350, 409, 387, 563, 1188, 651, 702, 915, 1122, 189, 919, 194, 480, 243, 101, 162, 334, 1106, 755, 905, 137, 813, 298, 452, 889, 170, 371, 767, 716, 79, 852, 326, 218, 373, 388, 644, 827, 62, 284, 535, 836, 574, 886, 238, 41, 681, 91, 643, 516, 153, 543, 1141, 901, 990, 648, 520, 95, 1049, 893, 658, 319, 
661, 341, 311, 765, 264, 926, 309, 551, 881, 136, 1132, 1095, 958, 762, 453, 577, 1008, 550, 34, 963, 1, 729, 1024, 289, 10, 826, 301, 1074, 1087, 390, 1145, 166, 55, 1091, 640, 691, 1140, 790, 779, 1164, 934, 81, 28, 471, 514, 273, 470, 903, 1082, 496, 1129, 67, 777, 1002, 824, 939, 1081, 358, 1173, 209, 478, 177, 579, 1026, 1080, 493, 987, 694, 627, 1154, 141, 626, 1104, 922, 594, 600, 791, 204, 1143, 873, 1011, 220, 395, 620, 582, 1187, 221, 770, 1115, 973, 181, 1165, 803, 728, 842, 1120, 320, 11, 650, 1184, 682, 412, 192, 636, 545, 230, 743, 1089, 956, 763, 730, 338, 560, 368, 1169, 841, 890, 960, 748, 727, 259, 414, 280, 494, 54, 522, 402, 184, 754, 392, 609, 30, 878, 451, 477, 148, 989, 587, 782, 801, 367, 413, 240, 720, 1168, 161, 1197, 263, 526, 732, 254, 528, 1172, 854, 433, 924, 553, 159, 114, 293, 1119, 176, 449, 337, 843, 686, 725, 869, 399, 871, 131, 38, 663, 829, 697, 1079, 884, 1186, 426, 988, 1031, 558, 180, 364, 72, 98, 589, 837, 877, 599, 86, 715, 695, 712, 437, 561, 265, 610, 500, 160, 1007, 1150, 303, 696, 856, 503, 429, 823, 1153, 1027, 489, 538, 495, 275, 383, 817, 144, 468, 366, 297, 509, 722, 927, 44, 795, 481, 674, 1128, 1004, 1177, 709, 995, 400, 656, 342, 870, 1068, 82, 125, 936, 260, 124, 908, 404, 353, 771, 143, 938, 733, 401, 382, 1072, 118, 1038, 502, 773, 224, 78, 740, 978, 959, 24, 291, 5, 75, 628, 602, 1126, 213, 1036, 497, 1176, 420, 61, 462, 831, 16, 845, 688, 793, 860, 203, 313, 1006, 103, 422, 994, 895, 1062, 315, 381 ] } ================================================ FILE: datainfo/single_pendulum/data_split_dict_3.json ================================================ { "test": [ 1194, 644, 1184, 23, 694, 620, 1067, 1157, 895, 1155, 486, 883, 555, 1153, 210, 1152, 1065, 942, 771, 1121, 283, 737, 642, 695, 86, 319, 539, 597, 835, 404, 64, 1096, 221, 157, 14, 634, 1161, 483, 1035, 571, 677, 773, 92, 90, 243, 850, 1167, 601, 41, 1181, 840, 136, 704, 181, 990, 987, 129, 254, 583, 546, 432, 213, 1109, 668, 334, 572, 58, 689, 475, 
834, 1093, 718, 790, 1038, 862, 1172, 893, 528, 444, 1013, 278, 73, 199, 748, 274, 910, 808, 874, 793, 968, 551, 63, 616, 87, 326, 131, 31, 798, 1071, 310, 474, 308, 813, 975, 1125, 1107, 963, 392, 479, 1128, 531, 960, 26, 134, 1189, 970, 757, 267, 1114, 487 ], "val": [ 710, 822, 925, 35, 46, 250, 829, 122, 293, 602, 878, 1029, 882, 232, 911, 349, 556, 295, 1190, 765, 678, 1079, 856, 467, 763, 33, 227, 734, 949, 972, 1145, 1158, 522, 637, 901, 852, 965, 1047, 4, 204, 159, 423, 1097, 36, 419, 581, 429, 297, 426, 649, 354, 1148, 277, 870, 812, 530, 298, 431, 132, 603, 353, 1051, 825, 175, 696, 570, 375, 645, 390, 69, 247, 460, 554, 923, 446, 1147, 163, 346, 897, 459, 683, 659, 208, 198, 891, 383, 671, 488, 1170, 455, 1024, 1115, 269, 55, 214, 772, 615, 540, 1086, 640, 379, 745, 363, 655, 611, 934, 514, 756, 43, 124, 45, 1002, 1119, 1074, 722, 954, 680, 123, 272, 1098 ], "train": [ 464, 1062, 890, 22, 660, 427, 290, 167, 469, 1159, 341, 712, 560, 457, 77, 952, 639, 462, 1185, 1195, 654, 135, 83, 797, 730, 372, 650, 218, 225, 787, 285, 265, 343, 279, 452, 1028, 397, 951, 948, 421, 855, 625, 158, 846, 1004, 1009, 753, 72, 466, 559, 914, 80, 591, 738, 142, 1120, 847, 408, 708, 973, 552, 236, 364, 1186, 207, 606, 1103, 338, 741, 622, 282, 178, 991, 263, 1054, 1112, 612, 586, 1072, 1193, 371, 1129, 480, 66, 1176, 781, 691, 507, 735, 853, 91, 575, 386, 192, 928, 67, 335, 785, 656, 585, 220, 789, 374, 347, 706, 231, 196, 188, 752, 681, 348, 324, 1177, 309, 151, 727, 929, 525, 688, 100, 95, 1141, 1124, 352, 519, 1090, 624, 1075, 623, 266, 351, 219, 723, 1133, 470, 1085, 75, 38, 1150, 510, 1031, 998, 800, 8, 641, 676, 370, 104, 288, 766, 415, 286, 292, 521, 411, 42, 1182, 443, 1110, 1092, 811, 873, 1073, 731, 842, 977, 350, 561, 788, 424, 1179, 312, 876, 906, 141, 321, 228, 564, 294, 791, 414, 922, 76, 118, 328, 138, 1166, 1156, 356, 628, 714, 841, 777, 1142, 799, 803, 103, 875, 844, 638, 1063, 1055, 726, 953, 1113, 313, 518, 315, 449, 935, 1160, 65, 881, 482, 6, 633, 784, 
296, 498, 10, 1081, 337, 331, 865, 867, 434, 1169, 380, 605, 394, 154, 889, 884, 393, 412, 143, 217, 500, 587, 783, 815, 1052, 945, 535, 262, 961, 1007, 635, 12, 284, 61, 1089, 767, 672, 1076, 1197, 607, 717, 1117, 327, 999, 866, 1149, 0, 399, 1010, 345, 230, 994, 667, 121, 1036, 49, 1144, 1199, 1101, 992, 489, 451, 249, 533, 979, 927, 137, 666, 595, 496, 770, 959, 212, 534, 1099, 978, 1100, 197, 89, 697, 686, 171, 913, 1017, 477, 1146, 211, 558, 827, 1108, 201, 156, 619, 1039, 179, 915, 1088, 679, 437, 169, 657, 96, 27, 246, 502, 627, 684, 473, 647, 1005, 53, 739, 1046, 303, 485, 851, 29, 955, 373, 576, 699, 172, 15, 1131, 287, 930, 725, 589, 238, 653, 148, 454, 304, 1171, 511, 59, 318, 170, 84, 97, 857, 1162, 48, 947, 112, 481, 252, 703, 941, 1026, 617, 981, 837, 1021, 16, 1042, 162, 664, 838, 760, 517, 1059, 251, 1033, 420, 1154, 886, 71, 1198, 325, 1104, 257, 1015, 299, 830, 1084, 614, 848, 506, 721, 966, 849, 715, 596, 1018, 107, 888, 759, 418, 305, 669, 504, 1014, 608, 1027, 536, 368, 19, 233, 1122, 224, 796, 1020, 447, 234, 698, 11, 545, 956, 750, 106, 216, 907, 34, 1135, 242, 899, 200, 82, 818, 450, 1140, 476, 209, 21, 402, 1053, 161, 248, 113, 387, 357, 819, 543, 833, 32, 1095, 1048, 916, 1087, 40, 366, 17, 396, 578, 1006, 439, 1080, 68, 983, 355, 682, 445, 168, 165, 648, 9, 898, 743, 39, 259, 119, 316, 1061, 950, 25, 472, 574, 782, 405, 513, 971, 1056, 456, 858, 674, 145, 289, 57, 986, 786, 1044, 1180, 709, 416, 821, 108, 1019, 940, 693, 828, 744, 646, 805, 177, 503, 239, 905, 2, 859, 526, 824, 764, 920, 317, 871, 562, 417, 1049, 631, 300, 378, 610, 206, 588, 174, 1043, 407, 149, 912, 621, 1008, 193, 189, 1139, 673, 565, 817, 155, 114, 816, 361, 806, 458, 413, 340, 362, 700, 253, 794, 516, 685, 845, 705, 832, 1187, 900, 1078, 754, 367, 401, 976, 755, 185, 969, 1041, 180, 1105, 768, 468, 652, 944, 854, 974, 629, 144, 573, 110, 1130, 742, 1060, 775, 153, 186, 724, 301, 60, 924, 520, 505, 936, 538, 939, 579, 1037, 1058, 860, 1050, 461, 377, 807, 1064, 99, 
256, 1057, 448, 115, 442, 78, 281, 917, 885, 557, 932, 400, 872, 701, 215, 747, 497, 388, 1034, 339, 342, 30, 919, 344, 670, 463, 365, 1003, 1016, 270, 780, 120, 1173, 93, 894, 187, 320, 663, 173, 261, 762, 636, 191, 140, 264, 194, 746, 260, 769, 880, 728, 235, 478, 964, 903, 74, 988, 809, 1025, 1143, 594, 1188, 1094, 736, 491, 758, 147, 879, 329, 1196, 861, 1138, 176, 205, 391, 490, 1164, 1083, 314, 222, 1132, 692, 598, 609, 938, 687, 302, 550, 237, 826, 887, 599, 150, 957, 273, 139, 190, 863, 453, 720, 1151, 509, 98, 166, 152, 410, 658, 1000, 495, 702, 529, 690, 593, 582, 802, 425, 164, 422, 240, 512, 776, 85, 823, 569, 358, 406, 613, 376, 931, 403, 716, 630, 465, 255, 661, 1163, 1030, 79, 184, 761, 332, 563, 962, 566, 102, 567, 109, 943, 732, 541, 523, 779, 37, 1123, 306, 592, 291, 713, 493, 1118, 117, 1192, 804, 933, 24, 183, 384, 675, 101, 792, 94, 896, 1070, 864, 241, 146, 892, 5, 995, 1127, 105, 544, 577, 1174, 751, 111, 1040, 385, 542, 1178, 749, 982, 707, 1011, 1069, 268, 1134, 1045, 333, 28, 133, 436, 229, 435, 868, 604, 711, 276, 548, 223, 996, 18, 909, 428, 1191, 430, 869, 719, 778, 508, 1102, 381, 937, 441, 993, 44, 820, 651, 1082, 1032, 926, 665, 618, 471, 382, 1068, 182, 580, 81, 814, 389, 740, 733, 3, 568, 130, 126, 438, 336, 195, 271, 369, 311, 600, 398, 662, 395, 359, 1116, 322, 160, 323, 501, 1066, 527, 88, 729, 1126, 985, 494, 70, 902, 127, 47, 1136, 1168, 56, 877, 1, 839, 795, 330, 484, 52, 433, 532, 1106, 632, 275, 1137, 1012, 774, 51, 409, 1091, 584, 643, 499, 7, 843, 1175, 1022, 1165, 280, 810, 245, 958, 967, 836, 908, 547, 125, 202, 226, 360, 62, 801, 831, 997, 553, 989, 258, 128, 1183, 116, 626, 1111, 54, 1001, 549, 537, 20, 980, 244, 307, 515, 1023, 1077, 984, 492, 13, 50, 590, 440, 921, 904, 918, 203, 946, 524 ] } ================================================ FILE: datainfo/swingstick_magnetic/data_split_dict_1.json ================================================ { "test": [ 2, 18, 4 ], "val": [ 3, 8 ], "train": [ 19, 5, 13, 9, 20, 
21, 11, 12, 0, 16, 1, 17, 6, 10, 7, 14, 15 ] } ================================================ FILE: datainfo/swingstick_magnetic/data_split_dict_2.json ================================================ { "test": [ 20, 2, 1 ], "val": [ 5, 11 ], "train": [ 18, 21, 13, 10, 4, 17, 7, 15, 6, 19, 12, 0, 14, 3, 16, 8, 9 ] } ================================================ FILE: datainfo/swingstick_magnetic/data_split_dict_3.json ================================================ { "test": [ 17, 18, 7 ], "val": [ 11, 4 ], "train": [ 1, 13, 16, 5, 10, 6, 19, 12, 14, 3, 8, 20, 21, 0, 9, 2, 15 ] } ================================================ FILE: datainfo/swingstick_non_magnetic/data_split_dict_1.json ================================================ { "test": [ 48, 60, 57, 63, 15, 32, 8, 72, 17 ], "val": [ 77, 0, 55, 49, 3, 62, 12, 26 ], "train": [ 69, 20, 80, 10, 4, 59, 67, 81, 5, 2, 37, 78, 83, 58, 71, 19, 38, 73, 25, 79, 70, 30, 9, 36, 51, 52, 53, 16, 54, 23, 47, 21, 7, 39, 11, 6, 45, 74, 50, 18, 65, 42, 44, 22, 75, 35, 31, 28, 14, 33, 76, 46, 27, 64, 43, 24, 56, 68, 66, 41, 61, 82, 1, 40, 13, 29, 34 ] } ================================================ FILE: datainfo/swingstick_non_magnetic/data_split_dict_2.json ================================================ { "test": [ 4, 27, 32, 39, 21, 46, 10, 11, 7 ], "val": [ 64, 56, 47, 65, 50, 55, 20, 74 ], "train": [ 0, 31, 70, 6, 49, 41, 2, 67, 17, 13, 62, 9, 5, 57, 19, 25, 78, 18, 37, 79, 45, 51, 83, 16, 69, 71, 36, 66, 12, 61, 30, 54, 38, 40, 22, 42, 72, 26, 28, 63, 53, 43, 23, 44, 77, 8, 48, 60, 52, 1, 14, 15, 82, 35, 81, 33, 68, 76, 24, 58, 73, 59, 29, 80, 3, 75, 34 ] } ================================================ FILE: datainfo/swingstick_non_magnetic/data_split_dict_3.json ================================================ { "test": [ 8, 74, 60, 77, 47, 16, 69, 75, 30 ], "val": [ 68, 73, 24, 29, 70, 33, 78, 1 ], "train": [ 3, 15, 14, 22, 21, 51, 46, 52, 35, 5, 41, 56, 40, 54, 62, 57, 53, 7, 64, 26, 45, 58, 11, 18, 
12, 32, 67, 9, 34, 20, 44, 43, 80, 13, 31, 39, 66, 6, 23, 82, 28, 36, 25, 27, 61, 38, 83, 17, 76, 63, 2, 37, 48, 10, 4, 49, 42, 0, 79, 81, 72, 59, 55, 65, 71, 19, 50 ] } ================================================ FILE: dataset.py ================================================ import os import sys import json import glob import torch import itertools import numpy as np from PIL import Image from scipy import misc from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms class NeuralPhysDataset(Dataset): def __init__(self, data_filepath, flag, seed, object_name="double_pendulum"): self.seed = seed self.flag = flag self.object_name = object_name self.data_filepath = data_filepath self.all_filelist = self.get_all_filelist() def get_all_filelist(self): filelist = [] obj_filepath = os.path.join(self.data_filepath, self.object_name) # get the video ids based on training or testing data with open(os.path.join('../datainfo', self.object_name, f'data_split_dict_{self.seed}.json'), 'r') as file: seq_dict = json.load(file) vid_list = seq_dict[self.flag] # go through all the selected videos and get the triplets: input(t, t+1), output(t+2) for vid_idx in vid_list: seq_filepath = os.path.join(obj_filepath, str(vid_idx)) num_frames = len(os.listdir(seq_filepath)) suf = os.listdir(seq_filepath)[0].split('.')[-1] for p_frame in range(num_frames - 3): par_list = [] for p in range(4): par_list.append(os.path.join(seq_filepath, str(p_frame + p) + '.' 
+ suf)) filelist.append(par_list) return filelist def __len__(self): return len(self.all_filelist) # 0, 1 -> 2, 3 def __getitem__(self, idx): par_list = self.all_filelist[idx] data = [] for i in range(2): data.append(self.get_data(par_list[i])) # 0, 1 data = torch.cat(data, 2) target = [] target.append(self.get_data(par_list[-2])) # 2 target.append(self.get_data(par_list[-1])) # 3 target = torch.cat(target, 2) filepath = '_'.join(par_list[0].split('/')[-2:]) return data, target, filepath def get_data(self, filepath): data = Image.open(filepath) data = data.resize((128, 128)) data = np.array(data) data = torch.tensor(data / 255.0) data = data.permute(2, 0, 1).float() return data class NeuralPhysLatentDynamicsDataset(Dataset): def __init__(self, data_filepath, flag, seed, object_name="double_pendulum"): self.seed = seed self.flag = flag self.object_name = object_name self.data_filepath = data_filepath self.all_filelist = self.get_all_filelist() def get_all_filelist(self): filelist = [] obj_filepath = os.path.join(self.data_filepath, self.object_name) # get the video ids based on training or testing data with open(os.path.join('../datainfo', self.object_name, f'data_split_dict_{self.seed}.json'), 'r') as file: seq_dict = json.load(file) vid_list = seq_dict[self.flag] # go through all the selected videos and get the triplets: input(t, t+1), output(t+2) for vid_idx in vid_list: seq_filepath = os.path.join(obj_filepath, str(vid_idx)) num_frames = len(os.listdir(seq_filepath)) suf = os.listdir(seq_filepath)[0].split('.')[-1] for p_frame in range(num_frames - 5): par_list = [] for p in range(6): par_list.append(os.path.join(seq_filepath, str(p_frame + p) + '.' 
+ suf)) filelist.append(par_list) return filelist def __len__(self): return len(self.all_filelist) # 0, 1 -> 2, 3 def __getitem__(self, idx): par_list = self.all_filelist[idx] data = [] for i in range(2): data.append(self.get_data(par_list[i])) # 0, 1 data = torch.cat(data, 2) target = [] target.append(self.get_data(par_list[2])) # 2 target.append(self.get_data(par_list[3])) # 3 target = torch.cat(target, 2) target_target = [] target_target.append(self.get_data(par_list[-2])) # 4 target_target.append(self.get_data(par_list[-1])) # 5 target_target = torch.cat(target_target, 2) filepath = '_'.join(par_list[0].split('/')[-2:]) return data, target, target_target, filepath def get_data(self, filepath): data = Image.open(filepath) data = data.resize((128, 128)) data = np.array(data) data = torch.tensor(data / 255.0) data = data.permute(2, 0, 1).float() return data ================================================ FILE: eval.py ================================================ import os import sys import glob import yaml import torch import pprint import shutil import numpy as np from tqdm import tqdm from munch import munchify from collections import OrderedDict from models import VisDynamicsModel from models_latentpred import VisLatentDynamicsModel from dataset import NeuralPhysDataset from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.loggers import TensorBoardLogger def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def load_config(filepath): with open(filepath, 'r') as stream: try: trainer_params = yaml.safe_load(stream) return trainer_params except yaml.YAMLError as exc: print(exc) def seed(cfg): torch.manual_seed(cfg.seed) if cfg.if_cuda: torch.cuda.manual_seed(cfg.seed) def main(): config_filepath = str(sys.argv[1]) checkpoint_filepath = str(sys.argv[2]) checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = 
munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=True, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) ckpt = torch.load(checkpoint_filepath) if 'refine' in cfg.model_name: ckpt = rename_ckpt_for_multi_models(ckpt) model.model.load_state_dict(ckpt) high_dim_checkpoint_filepath = str(sys.argv[3]) high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(high_dim_checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.high_dim_model.load_state_dict(ckpt) else: model.load_state_dict(ckpt['state_dict']) model.eval() model.freeze() trainer = Trainer(gpus=1, deterministic=True, amp_backend='native', default_root_dir=log_dir, val_check_interval=1.0) trainer.test(model) model.test_save() def main_latentpred(): config_filepath = str(sys.argv[1]) checkpoint_filepath = str(sys.argv[2]) high_dim_checkpoint_filepath = str(sys.argv[3]) refine_checkpoint_filepath = str(sys.argv[4]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisLatentDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=True, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) model.load_model(checkpoint_filepath) model.load_high_dim_refine_model(high_dim_checkpoint_filepath, refine_checkpoint_filepath) 
model.extract_decoder_from_refine_model() model.eval() model.freeze() trainer = Trainer(gpus=1, deterministic=True, amp_backend='native', default_root_dir=log_dir, val_check_interval=1.0) trainer.test(model) def rename_ckpt_for_multi_models(ckpt): renamed_state_dict = OrderedDict() for k, v in ckpt['state_dict'].items(): if 'high_dim_model' in k: name = k.replace('high_dim_model.', '') else: name = k.replace('model.', '') renamed_state_dict[name] = v return renamed_state_dict # gather latent variables by running training data on the trained high-dim models def gather_latent_from_trained_high_dim_model(): config_filepath = str(sys.argv[1]) checkpoint_filepath = str(sys.argv[2]) checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=True, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) ckpt = torch.load(checkpoint_filepath) model.load_state_dict(ckpt['state_dict']) model = model.to('cuda') model.eval() model.freeze() # prepare train and val dataset kwargs = {'num_workers': cfg.num_workers, 'pin_memory': True} if cfg.if_cuda else {} train_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath, flag='train', seed=cfg.seed, object_name=cfg.dataset) val_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath, flag='val', seed=cfg.seed, object_name=cfg.dataset) # prepare train and val loader train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=cfg.train_batch, shuffle=True, **kwargs) val_loader = 
torch.utils.data.DataLoader(dataset=val_dataset, batch_size=cfg.val_batch, shuffle=False, **kwargs) # run train forward pass to save the latent vector for training the refine network later all_filepaths = [] all_latents = [] var_log_dir = os.path.join(log_dir, 'variables') for batch_idx, (data, target, filepath) in enumerate(tqdm(train_loader)): if cfg.model_name == 'encoder-decoder': output, latent = model.model(data.cuda()) if cfg.model_name == 'encoder-decoder-64': output, latent = model.model(data.cuda(), data.cuda(), False) # save the latent vectors all_filepaths.extend(filepath) for idx in range(data.shape[0]): latent_tmp = latent[idx].view(1, -1)[0] latent_tmp = latent_tmp.cpu().detach().numpy() all_latents.append(latent_tmp) mkdir(var_log_dir+'_train') np.save(os.path.join(var_log_dir+'_train', 'ids.npy'), all_filepaths) np.save(os.path.join(var_log_dir+'_train', 'latent.npy'), all_latents) # run val forward pass to save the latent vector for validating the refine network later all_filepaths = [] all_latents = [] var_log_dir = os.path.join(log_dir, 'variables') for batch_idx, (data, target, filepath) in enumerate(tqdm(val_loader)): if cfg.model_name == 'encoder-decoder': output, latent = model.model(data.cuda()) if cfg.model_name == 'encoder-decoder-64': output, latent = model.model(data.cuda(), data.cuda(), False) # save the latent vectors all_filepaths.extend(filepath) for idx in range(data.shape[0]): latent_tmp = latent[idx].view(1, -1)[0] latent_tmp = latent_tmp.cpu().detach().numpy() all_latents.append(latent_tmp) mkdir(var_log_dir+'_val') np.save(os.path.join(var_log_dir+'_val', 'ids.npy'), all_filepaths) np.save(os.path.join(var_log_dir+'_val', 'latent.npy'), all_latents) # gather latent variables by running training data on the trained refine models def gather_latent_from_trained_refine_model(): config_filepath = str(sys.argv[1]) checkpoint_filepath = str(sys.argv[2]) checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] 
cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=True, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) ckpt = torch.load(checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.model.load_state_dict(ckpt) high_dim_checkpoint_filepath = str(sys.argv[3]) high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(high_dim_checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.high_dim_model.load_state_dict(ckpt) model = model.to('cuda') model.eval() model.freeze() # prepare train and val dataset kwargs = {'num_workers': cfg.num_workers, 'pin_memory': True} if cfg.if_cuda else {} train_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath, flag='train', seed=cfg.seed, object_name=cfg.dataset) val_dataset = NeuralPhysDataset(data_filepath=cfg.data_filepath, flag='val', seed=cfg.seed, object_name=cfg.dataset) # prepare train and val loader train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=cfg.train_batch, shuffle=True, **kwargs) val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=cfg.val_batch, shuffle=False, **kwargs) # run train forward pass to save the latent vector for training data all_filepaths = [] all_latents = [] all_refine_latents = [] all_reconstructed_latents = [] var_log_dir = os.path.join(log_dir, 'variables') for batch_idx, (data, target, filepath) in enumerate(tqdm(train_loader)): _, latent = model.high_dim_model(data.cuda(), data.cuda(), False) latent = latent.squeeze(-1).squeeze(-1) 
latent_reconstructed, latent_latent = model.model(latent) # save the latent vectors all_filepaths.extend(filepath) for idx in range(data.shape[0]): latent_tmp = latent[idx].view(1, -1)[0] latent_tmp = latent_tmp.cpu().detach().numpy() all_latents.append(latent_tmp) # save latent_latent: the latent vector in the refine network latent_latent_tmp = latent_latent[idx].view(1, -1)[0] latent_latent_tmp = latent_latent_tmp.cpu().detach().numpy() all_refine_latents.append(latent_latent_tmp) # save latent_reconstructed: the latent vector reconstructed by the entire refine network latent_reconstructed_tmp = latent_reconstructed[idx].view(1, -1)[0] latent_reconstructed_tmp = latent_reconstructed_tmp.cpu().detach().numpy() all_reconstructed_latents.append(latent_reconstructed_tmp) mkdir(var_log_dir+'_train') np.save(os.path.join(var_log_dir+'_train', 'ids.npy'), all_filepaths) np.save(os.path.join(var_log_dir+'_train', 'latent.npy'), all_latents) np.save(os.path.join(var_log_dir+'_train', 'refine_latent.npy'), all_refine_latents) np.save(os.path.join(var_log_dir+'_train', 'reconstructed_latent.npy'), all_reconstructed_latents) # run val forward pass to save the latent vector for validation data all_filepaths = [] all_latents = [] all_refine_latents = [] all_reconstructed_latents = [] var_log_dir = os.path.join(log_dir, 'variables') for batch_idx, (data, target, filepath) in enumerate(tqdm(val_loader)): _, latent = model.high_dim_model(data.cuda(), data.cuda(), False) latent = latent.squeeze(-1).squeeze(-1) latent_reconstructed, latent_latent = model.model(latent) # save the latent vectors all_filepaths.extend(filepath) for idx in range(data.shape[0]): latent_tmp = latent[idx].view(1, -1)[0] latent_tmp = latent_tmp.cpu().detach().numpy() all_latents.append(latent_tmp) # save latent_latent: the latent vector in the refine network latent_latent_tmp = latent_latent[idx].view(1, -1)[0] latent_latent_tmp = latent_latent_tmp.cpu().detach().numpy() 
all_refine_latents.append(latent_latent_tmp) # save latent_reconstructed: the latent vector reconstructed by the entire refine network latent_reconstructed_tmp = latent_reconstructed[idx].view(1, -1)[0] latent_reconstructed_tmp = latent_reconstructed_tmp.cpu().detach().numpy() all_reconstructed_latents.append(latent_reconstructed_tmp) mkdir(var_log_dir+'_val') np.save(os.path.join(var_log_dir+'_val', 'ids.npy'), all_filepaths) np.save(os.path.join(var_log_dir+'_val', 'latent.npy'), all_latents) np.save(os.path.join(var_log_dir+'_val', 'refine_latent.npy'), all_refine_latents) np.save(os.path.join(var_log_dir+'_val', 'reconstructed_latent.npy'), all_reconstructed_latents) if __name__ == '__main__': if str(sys.argv[4]) == 'eval-train': gather_latent_from_trained_high_dim_model() elif str(sys.argv[4]) == 'eval-refine-train': gather_latent_from_trained_refine_model() elif str(sys.argv[5]) == 'eval-latentpred': main_latentpred() else: main() ================================================ FILE: main.py ================================================ import os import sys import yaml import torch import pprint from munch import munchify from models import VisDynamicsModel from models_latentpred import VisLatentDynamicsModel from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import ModelCheckpoint def load_config(filepath): with open(filepath, 'r') as stream: try: trainer_params = yaml.safe_load(stream) return trainer_params except yaml.YAMLError as exc: print(exc) def seed(cfg): torch.manual_seed(cfg.seed) if cfg.if_cuda: torch.cuda.manual_seed(cfg.seed) def main(): config_filepath = str(sys.argv[1]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=False, 
gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) # define callback for selecting checkpoints during training checkpoint_callback = ModelCheckpoint( filepath=log_dir + "/lightning_logs/checkpoints/{epoch}_{val_loss}", verbose=True, monitor='val_loss', mode='min', prefix='') # define trainer trainer = Trainer(gpus=cfg.num_gpus, max_epochs=cfg.epochs, deterministic=True, accelerator='ddp', amp_backend='native', default_root_dir=log_dir, val_check_interval=1.0, checkpoint_callback=checkpoint_callback) trainer.fit(model) def main_latentpred(): config_filepath = str(sys.argv[2]) high_dim_checkpoint_filepath = str(sys.argv[3]) refine_checkpoint_filepath = str(sys.argv[4]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisLatentDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=False, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) model.load_high_dim_refine_model(high_dim_checkpoint_filepath, refine_checkpoint_filepath) # define callback for selecting checkpoints during training checkpoint_callback = ModelCheckpoint( filepath=log_dir + "/lightning_logs/checkpoints/{epoch}_{val_loss}", verbose=True, monitor='val_loss', mode='min', prefix='') # define trainer trainer = Trainer(gpus=cfg.num_gpus, max_epochs=cfg.epochs, deterministic=True, accelerator='ddp', amp_backend='native', default_root_dir=log_dir, val_check_interval=1.0, checkpoint_callback=checkpoint_callback) trainer.fit(model) if __name__ 
== '__main__': if sys.argv[1] == 'latentpred': main_latentpred() else: main() ================================================ FILE: model_utils.py ================================================ import torch import numpy as np import torch.nn as nn import torch.nn.functional as F def conv2d_bn_relu(inch,outch,kernel_size,stride=1,padding=1): convlayer = torch.nn.Sequential( torch.nn.Conv2d(inch,outch,kernel_size=kernel_size,stride=stride,padding=padding), torch.nn.BatchNorm2d(outch), torch.nn.ReLU() ) return convlayer def conv2d_bn_sigmoid(inch,outch,kernel_size,stride=1,padding=1): convlayer = torch.nn.Sequential( torch.nn.Conv2d(inch,outch,kernel_size=kernel_size,stride=stride,padding=padding), torch.nn.BatchNorm2d(outch), torch.nn.Sigmoid() ) return convlayer def deconv_sigmoid(inch,outch,kernel_size,stride=1,padding=1): convlayer = torch.nn.Sequential( torch.nn.ConvTranspose2d(inch,outch,kernel_size=kernel_size,stride=stride,padding=padding), torch.nn.Sigmoid() ) return convlayer def deconv_relu(inch,outch,kernel_size,stride=1,padding=1): convlayer = torch.nn.Sequential( torch.nn.ConvTranspose2d(inch,outch,kernel_size=kernel_size,stride=stride,padding=padding), torch.nn.BatchNorm2d(outch), torch.nn.ReLU() ) return convlayer class EncoderDecoder(torch.nn.Module): def __init__(self, in_channels): super(EncoderDecoder,self).__init__() self.conv_stack1 = torch.nn.Sequential( conv2d_bn_relu(in_channels,32,4,stride=2), conv2d_bn_relu(32,32,3) ) self.conv_stack2 = torch.nn.Sequential( conv2d_bn_relu(32,32,4,stride=2), conv2d_bn_relu(32,32,3) ) self.conv_stack3 = torch.nn.Sequential( conv2d_bn_relu(32,64,4,stride=2), conv2d_bn_relu(64,64,3) ) self.conv_stack4 = torch.nn.Sequential( conv2d_bn_relu(64,128,4,stride=2), conv2d_bn_relu(128,128,3), ) self.conv_stack5 = torch.nn.Sequential( conv2d_bn_relu(128,128,(3,4),stride=(1,2)), conv2d_bn_relu(128,128,3), ) self.deconv_5 = deconv_relu(128,64,(3,4),stride=(1,2)) self.deconv_4 = deconv_relu(67,64,4,stride=2) 
self.deconv_3 = deconv_relu(67,32,4,stride=2) self.deconv_2 = deconv_relu(35,16,4,stride=2) self.deconv_1 = deconv_sigmoid(19,3,4,stride=2) self.predict_5 = torch.nn.Conv2d(128,3,3,stride=1,padding=1) self.predict_4 = torch.nn.Conv2d(67,3,3,stride=1,padding=1) self.predict_3 = torch.nn.Conv2d(67,3,3,stride=1,padding=1) self.predict_2 = torch.nn.Conv2d(35,3,3,stride=1,padding=1) self.up_sample_5 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,(3,4),stride=(1,2),padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_4 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_3 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_2 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) def encoder(self, x): conv1_out = self.conv_stack1(x) conv2_out = self.conv_stack2(conv1_out) conv3_out = self.conv_stack3(conv2_out) conv4_out = self.conv_stack4(conv3_out) conv5_out = self.conv_stack5(conv4_out) return conv5_out def decoder(self, x): deconv5_out = self.deconv_5(x) predict_5_out = self.up_sample_5(self.predict_5(x)) concat_5 = torch.cat([deconv5_out, predict_5_out],dim=1) deconv4_out = self.deconv_4(concat_5) predict_4_out = self.up_sample_4(self.predict_4(concat_5)) concat_4 = torch.cat([deconv4_out,predict_4_out],dim=1) deconv3_out = self.deconv_3(concat_4) predict_3_out = self.up_sample_3(self.predict_3(concat_4)) concat2 = torch.cat([deconv3_out,predict_3_out],dim=1) deconv2_out = self.deconv_2(concat2) predict_2_out = self.up_sample_2(self.predict_2(concat2)) concat1 = torch.cat([deconv2_out,predict_2_out],dim=1) predict_out = self.deconv_1(concat1) return predict_out def forward(self,x): latent = self.encoder(x) out = self.decoder(latent) return out, latent class EncoderDecoder64x1x1(torch.nn.Module): def __init__(self, in_channels): 
super(EncoderDecoder64x1x1,self).__init__() self.conv_stack1 = torch.nn.Sequential( conv2d_bn_relu(in_channels,32,4,stride=2), conv2d_bn_relu(32,32,3) ) self.conv_stack2 = torch.nn.Sequential( conv2d_bn_relu(32,32,4,stride=2), conv2d_bn_relu(32,32,3) ) self.conv_stack3 = torch.nn.Sequential( conv2d_bn_relu(32,64,4,stride=2), conv2d_bn_relu(64,64,3) ) self.conv_stack4 = torch.nn.Sequential( conv2d_bn_relu(64,64,4,stride=2), conv2d_bn_relu(64,64,3), ) self.conv_stack5 = torch.nn.Sequential( conv2d_bn_relu(64,64,4,stride=2), conv2d_bn_relu(64,64,3), ) self.conv_stack6 = torch.nn.Sequential( conv2d_bn_relu(64,64,4,stride=2), conv2d_bn_relu(64,64,3), ) self.conv_stack7 = torch.nn.Sequential( conv2d_bn_relu(64,64,4,stride=2), conv2d_bn_relu(64,64,3), ) self.conv_stack8 = torch.nn.Sequential( conv2d_bn_relu(64, 64, (3,4), stride=(1,2)), conv2d_bn_relu(64, 64, 3), ) self.deconv_8 = deconv_relu(64,64,(3,4),stride=(1,2)) self.deconv_7 = deconv_relu(67,64,4,stride=2) self.deconv_6 = deconv_relu(67,64,4,stride=2) self.deconv_5 = deconv_relu(67,64,4,stride=2) self.deconv_4 = deconv_relu(67,64,4,stride=2) self.deconv_3 = deconv_relu(67,32,4,stride=2) self.deconv_2 = deconv_relu(35,16,4,stride=2) self.deconv_1 = deconv_sigmoid(19,3,4,stride=2) self.predict_8 = torch.nn.Conv2d(64,3,3,stride=1,padding=1) self.predict_7 = torch.nn.Conv2d(67,3,3,stride=1,padding=1) self.predict_6 = torch.nn.Conv2d(67,3,3,stride=1,padding=1) self.predict_5 = torch.nn.Conv2d(67,3,3,stride=1,padding=1) self.predict_4 = torch.nn.Conv2d(67,3,3,stride=1,padding=1) self.predict_3 = torch.nn.Conv2d(67,3,3,stride=1,padding=1) self.predict_2 = torch.nn.Conv2d(35,3,3,stride=1,padding=1) self.up_sample_8 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,(3,4),stride=(1,2),padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_7 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_6 = torch.nn.Sequential( 
torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_5 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_4 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_3 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) self.up_sample_2 = torch.nn.Sequential( torch.nn.ConvTranspose2d(3,3,4,stride=2,padding=1,bias=False), torch.nn.Sigmoid() ) def encoder(self, x): conv1_out = self.conv_stack1(x) conv2_out = self.conv_stack2(conv1_out) conv3_out = self.conv_stack3(conv2_out) conv4_out = self.conv_stack4(conv3_out) conv5_out = self.conv_stack5(conv4_out) conv6_out = self.conv_stack6(conv5_out) conv7_out = self.conv_stack7(conv6_out) conv8_out = self.conv_stack8(conv7_out) return conv8_out def decoder(self, x): deconv8_out = self.deconv_8(x) predict_8_out = self.up_sample_8(self.predict_8(x)) concat_7 = torch.cat([deconv8_out, predict_8_out], dim=1) deconv7_out = self.deconv_7(concat_7) predict_7_out = self.up_sample_7(self.predict_7(concat_7)) concat_6 = torch.cat([deconv7_out,predict_7_out],dim=1) deconv6_out = self.deconv_6(concat_6) predict_6_out = self.up_sample_6(self.predict_6(concat_6)) concat_5 = torch.cat([deconv6_out,predict_6_out],dim=1) deconv5_out = self.deconv_5(concat_5) predict_5_out = self.up_sample_5(self.predict_5(concat_5)) concat_4 = torch.cat([deconv5_out,predict_5_out],dim=1) deconv4_out = self.deconv_4(concat_4) predict_4_out = self.up_sample_4(self.predict_4(concat_4)) concat_3 = torch.cat([deconv4_out,predict_4_out],dim=1) deconv3_out = self.deconv_3(concat_3) predict_3_out = self.up_sample_3(self.predict_3(concat_3)) concat2 = torch.cat([deconv3_out,predict_3_out],dim=1) deconv2_out = self.deconv_2(concat2) predict_2_out = self.up_sample_2(self.predict_2(concat2)) concat1 = 
torch.cat([deconv2_out,predict_2_out],dim=1) predict_out = self.deconv_1(concat1) return predict_out def forward(self,x, reconstructed_latent, refine_latent): if refine_latent != True: latent = self.encoder(x) out = self.decoder(latent) return out, latent else: latent = self.encoder(x) out = self.decoder(reconstructed_latent) return out, latent class SirenLayer(nn.Module): def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False): super().__init__() self.in_f = in_f self.w0 = w0 self.linear = nn.Linear(in_f, out_f) self.is_first = is_first self.is_last = is_last self.init_weights() def init_weights(self): b = 1 / self.in_f if self.is_first else np.sqrt(6 / self.in_f) / self.w0 with torch.no_grad(): self.linear.weight.uniform_(-b, b) def forward(self, x): x = self.linear(x) return x if self.is_last else torch.sin(self.w0 * x) class RefineDoublePendulumModel(torch.nn.Module): def __init__(self, in_channels): super(RefineDoublePendulumModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 4) self.layer5 = SirenLayer(4, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = self.layer7(x) x = self.layer8(x) return x, latent class RefineElasticPendulumModel(torch.nn.Module): def __init__(self, in_channels): super(RefineElasticPendulumModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 6) self.layer5 = SirenLayer(6, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = 
self.layer2(x) x = self.layer3(x) latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = self.layer7(x) x = self.layer8(x) return x, latent class RefineReactionDiffusionModel(torch.nn.Module): def __init__(self, in_channels): super(RefineReactionDiffusionModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 2) self.layer5 = SirenLayer(2, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = self.layer7(x) x = self.layer8(x) return x, latent class RefineSwingStickNonMagneticModel(torch.nn.Module): def __init__(self, in_channels): super(RefineSwingStickNonMagneticModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 4) self.layer5 = SirenLayer(4, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = self.layer7(x) x = self.layer8(x) return x, latent class RefineSinglePendulumModel(torch.nn.Module): def __init__(self, in_channels): super(RefineSinglePendulumModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 2) self.layer5 = SirenLayer(2, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) 
latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = self.layer7(x) x = self.layer8(x) return x, latent class RefineCircularMotionModel(torch.nn.Module): def __init__(self, in_channels): super(RefineCircularMotionModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 2) self.layer5 = SirenLayer(2, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = self.layer7(x) x = self.layer8(x) return x, latent class RefineAirDancerModel(torch.nn.Module): def __init__(self, in_channels): super(RefineAirDancerModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 8) self.layer5 = SirenLayer(8, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = self.layer7(x) x = self.layer8(x) return x, latent class RefineLavaLampModel(torch.nn.Module): def __init__(self, in_channels): super(RefineLavaLampModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 8) self.layer5 = SirenLayer(8, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = 
self.layer7(x) x = self.layer8(x) return x, latent class RefineFireModel(torch.nn.Module): def __init__(self, in_channels): super(RefineFireModel, self).__init__() self.layer1 = SirenLayer(in_channels, 128, is_first=True) self.layer2 = SirenLayer(128, 64) self.layer3 = SirenLayer(64, 32) self.layer4 = SirenLayer(32, 24) self.layer5 = SirenLayer(24, 32) self.layer6 = SirenLayer(32, 64) self.layer7 = SirenLayer(64, 128) self.layer8 = SirenLayer(128, in_channels, is_last=True) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) latent = self.layer4(x) x = self.layer5(latent) x = self.layer6(x) x = self.layer7(x) x = self.layer8(x) return x, latent class RefineModelReLU(torch.nn.Module): def __init__(self, in_channels): super(RefineModelReLU, self).__init__() self.layer1 = nn.Linear(in_channels, 128) self.relu1 = nn.ReLU() self.layer2 = nn.Linear(128, 64) self.relu2 = nn.ReLU() self.layer3 = nn.Linear(64, 4) self.relu3 = nn.ReLU() self.layer4 = nn.Linear(4, 64) self.relu4 = nn.ReLU() self.layer5 = nn.Linear(64, 128) self.relu5 = nn.ReLU() self.layer6 = nn.Linear(128, in_channels) def forward(self, x): x = self.relu1(self.layer1(x)) x = self.relu2(self.layer2(x)) latent = self.relu3(self.layer3(x)) x = self.relu4(self.layer4(latent)) x = self.relu5(self.layer5(x)) x = self.layer6(x) return x, latent class LatentPredModel(torch.nn.Module): def __init__(self, in_channels): super(LatentPredModel, self).__init__() self.layer1 = nn.Linear(in_channels, 32) self.relu1 = nn.ReLU() self.layer2 = nn.Linear(32, 64) self.relu2 = nn.ReLU() self.layer3 = nn.Linear(64, 64) self.relu3 = nn.ReLU() self.layer4 = nn.Linear(64, 64) self.relu4 = nn.ReLU() self.layer5 = nn.Linear(64, 32) self.relu5 = nn.ReLU() self.layer6 = nn.Linear(32, in_channels) def forward(self, x): x = self.relu1(self.layer1(x)) x = self.relu2(self.layer2(x)) x = self.relu3(self.layer3(x)) x = self.relu4(self.layer4(x)) x = self.relu5(self.layer5(x)) x = self.layer6(x) return x 
================================================ FILE: models.py ================================================ import os import torch import shutil import numpy as np from torch import nn import pytorch_lightning as pl import torch.nn.functional as F from torchvision import transforms import torchvision.models as models from collections import OrderedDict from torch.utils.data import DataLoader from torchvision.utils import save_image from dataset import NeuralPhysDataset from model_utils import (EncoderDecoder, EncoderDecoder64x1x1, RefineDoublePendulumModel, RefineSinglePendulumModel, RefineCircularMotionModel, RefineModelReLU, RefineSwingStickNonMagneticModel, RefineAirDancerModel, RefineLavaLampModel, RefineFireModel, RefineElasticPendulumModel, RefineReactionDiffusionModel) def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) class VisDynamicsModel(pl.LightningModule): def __init__(self, lr: float=1e-4, seed: int=1, if_cuda: bool=True, if_test: bool=False, gamma: float=0.5, log_dir: str='logs', train_batch: int=512, val_batch: int=256, test_batch: int=256, num_workers: int=8, model_name: str='encoder-decoder-64', data_filepath: str='data', dataset: str='single_pendulum', lr_schedule: list=[20, 50, 100]) -> None: super().__init__() self.save_hyperparameters() self.kwargs = {'num_workers': self.hparams.num_workers, 'pin_memory': True} if self.hparams.if_cuda else {} # create visualization saving folder if testing self.pred_log_dir = os.path.join(self.hparams.log_dir, 'predictions') self.var_log_dir = os.path.join(self.hparams.log_dir, 'variables') if not self.hparams.if_test: mkdir(self.pred_log_dir) mkdir(self.var_log_dir) self.__build_model() def __build_model(self): # model if self.hparams.model_name == 'encoder-decoder': self.model = EncoderDecoder(in_channels=3) if self.hparams.model_name == 'encoder-decoder-64': self.model = EncoderDecoder64x1x1(in_channels=3) if self.hparams.model_name == 'refine-64' and 
self.hparams.dataset == 'single_pendulum': self.model = RefineSinglePendulumModel(in_channels=64) if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'double_pendulum': self.model = RefineDoublePendulumModel(in_channels=64) if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'circular_motion': self.model = RefineCircularMotionModel(in_channels=64) if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'swingstick_non_magnetic': self.model = RefineSwingStickNonMagneticModel(in_channels=64) if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'air_dancer': self.model = RefineAirDancerModel(in_channels=64) if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'lava_lamp': self.model = RefineLavaLampModel(in_channels=64) if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'fire': self.model = RefineFireModel(in_channels=64) if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'elastic_pendulum': self.model = RefineElasticPendulumModel(in_channels=64) if self.hparams.model_name == 'refine-64' and self.hparams.dataset == 'reaction_diffusion': self.model = RefineReactionDiffusionModel(in_channels=64) if 'refine' in self.hparams.model_name and self.hparams.if_test: self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) # loss self.loss_func = nn.MSELoss() def train_forward(self, x): if self.hparams.model_name == 'encoder-decoder' or 'refine' in self.hparams.model_name: output, latent = self.model(x) if self.hparams.model_name == 'encoder-decoder-64': output, latent = self.model(x, x, False) return output, latent def training_step(self, batch, batch_idx): data, target, filepath = batch output, latent = self.train_forward(data) train_loss = self.loss_func(output, target) self.log('train_loss', train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) return train_loss def validation_step(self, batch, batch_idx): data, target, filepath = batch 
output, latent = self.train_forward(data) val_loss = self.loss_func(output, target) self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) return val_loss def test_step(self, batch, batch_idx): if self.hparams.model_name == 'encoder-decoder' or self.hparams.model_name == 'encoder-decoder-64': data, target, filepath = batch if self.hparams.model_name == 'encoder-decoder': output, latent = self.model(data) if self.hparams.model_name == 'encoder-decoder-64': output, latent = self.model(data, data, False) test_loss = self.loss_func(output, target) self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) # save the output images and latent vectors self.all_filepaths.extend(filepath) for idx in range(data.shape[0]): comparison = torch.cat([data[idx,:, :, :128].unsqueeze(0), data[idx,:, :, 128:].unsqueeze(0), target[idx, :, :, :128].unsqueeze(0), target[idx, :, :, 128:].unsqueeze(0), output[idx, :, :, :128].unsqueeze(0), output[idx, :, :, 128:].unsqueeze(0)]) save_image(comparison.cpu(), os.path.join(self.pred_log_dir, filepath[idx]), nrow=1) latent_tmp = latent[idx].view(1, -1)[0] latent_tmp = latent_tmp.cpu().detach().numpy() self.all_latents.append(latent_tmp) if 'refine' in self.hparams.model_name: data, target, filepath = batch _, latent = self.high_dim_model(data, data, False) latent = latent.squeeze(-1).squeeze(-1) latent_reconstructed, latent_latent = self.model(latent) output, _ = self.high_dim_model(data, latent_reconstructed.unsqueeze(2).unsqueeze(3), True) # calculate losses pixel_reconstruction_loss = self.loss_func(output, target) test_loss = self.loss_func(latent_reconstructed, latent) self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) self.log('pixel_reconstruction_loss', pixel_reconstruction_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) # save the output images and latent vectors self.all_filepaths.extend(filepath) for idx in 
range(data.shape[0]): comparison = torch.cat([data[idx, :, :, :128].unsqueeze(0), data[idx, :, :, 128:].unsqueeze(0), target[idx, :, :, :128].unsqueeze(0), target[idx, :, :, 128:].unsqueeze(0), output[idx, :, :, :128].unsqueeze(0), output[idx, :, :, 128:].unsqueeze(0)]) save_image(comparison.cpu(), os.path.join(self.pred_log_dir, filepath[idx]), nrow=1) latent_tmp = latent[idx].view(1, -1)[0] latent_tmp = latent_tmp.cpu().detach().numpy() self.all_latents.append(latent_tmp) # save latent_latent: the latent vector in the refine network latent_latent_tmp = latent_latent[idx].view(1, -1)[0] latent_latent_tmp = latent_latent_tmp.cpu().detach().numpy() self.all_refine_latents.append(latent_latent_tmp) # save latent_reconstructed: the latent vector reconstructed by the entire refine network latent_reconstructed_tmp = latent_reconstructed[idx].view(1, -1)[0] latent_reconstructed_tmp = latent_reconstructed_tmp.cpu().detach().numpy() self.all_reconstructed_latents.append(latent_reconstructed_tmp) def test_save(self): if self.hparams.model_name == 'encoder-decoder' or self.hparams.model_name == 'encoder-decoder-64': np.save(os.path.join(self.var_log_dir, 'ids.npy'), self.all_filepaths) np.save(os.path.join(self.var_log_dir, 'latent.npy'), self.all_latents) if 'refine' in self.hparams.model_name: np.save(os.path.join(self.var_log_dir, 'ids.npy'), self.all_filepaths) np.save(os.path.join(self.var_log_dir, 'latent.npy'), self.all_latents) np.save(os.path.join(self.var_log_dir, 'refine_latent.npy'), self.all_refine_latents) np.save(os.path.join(self.var_log_dir, 'reconstructed_latent.npy'), self.all_reconstructed_latents) def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.hparams.lr_schedule, gamma=self.hparams.gamma) return [optimizer], [scheduler] def paths_to_tuple(self, paths): new_paths = [] for i in range(len(paths)): tmp = 
paths[i].split('.')[0].split('_') new_paths.append((int(tmp[0]), int(tmp[1]))) return new_paths def setup(self, stage=None): if stage == 'fit': # for the training of the refine network, we need to have the latent data as the dataset if 'refine' in self.hparams.model_name: high_dim_var_log_dir = self.var_log_dir.replace('refine', 'encoder-decoder') train_data = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_train', 'latent.npy'))) train_target = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_train', 'latent.npy'))) val_data = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_val', 'latent.npy'))) val_target = torch.FloatTensor(np.load(os.path.join(high_dim_var_log_dir+'_val', 'latent.npy'))) train_filepaths = list(np.load(os.path.join(high_dim_var_log_dir+'_train', 'ids.npy'))) val_filepaths = list(np.load(os.path.join(high_dim_var_log_dir+'_val', 'ids.npy'))) # convert the file strings into tuple so that we can use TensorDataset to load everything together train_filepaths = torch.Tensor(self.paths_to_tuple(train_filepaths)) val_filepaths = torch.Tensor(self.paths_to_tuple(val_filepaths)) self.train_dataset = torch.utils.data.TensorDataset(train_data, train_target, train_filepaths) self.val_dataset = torch.utils.data.TensorDataset(val_data, val_target, val_filepaths) else: self.train_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath, flag='train', seed=self.hparams.seed, object_name=self.hparams.dataset) self.val_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath, flag='val', seed=self.hparams.seed, object_name=self.hparams.dataset) if stage == 'test': self.test_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath, flag='test', seed=self.hparams.seed, object_name=self.hparams.dataset) # initialize lists for saving variables and latents during testing self.all_filepaths = [] self.all_latents = [] self.all_refine_latents = [] self.all_reconstructed_latents = [] def 
train_dataloader(self): train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset, batch_size=self.hparams.train_batch, shuffle=True, **self.kwargs) return train_loader def val_dataloader(self): val_loader = torch.utils.data.DataLoader(dataset=self.val_dataset, batch_size=self.hparams.val_batch, shuffle=False, **self.kwargs) return val_loader def test_dataloader(self): test_loader = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.hparams.test_batch, shuffle=False, **self.kwargs) return test_loader ================================================ FILE: models_latentpred.py ================================================ import os import glob import torch import shutil import numpy as np from torch import nn import pytorch_lightning as pl import torch.nn.functional as F from torchvision import transforms import torchvision.models as models from collections import OrderedDict from torch.utils.data import DataLoader from torchvision.utils import save_image from dataset import NeuralPhysDataset, NeuralPhysLatentDynamicsDataset from model_utils import (EncoderDecoder, EncoderDecoder64x1x1, RefineDoublePendulumModel, RefineSinglePendulumModel, RefineCircularMotionModel, RefineModelReLU, RefineSwingStickNonMagneticModel, RefineAirDancerModel, RefineLavaLampModel, RefineFireModel, RefineElasticPendulumModel, RefineReactionDiffusionModel, LatentPredModel) def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def rename_ckpt_for_multi_models(ckpt): renamed_state_dict = OrderedDict() for k, v in ckpt['state_dict'].items(): if k.split('.')[0] == 'model': name = k.replace('model.', '') renamed_state_dict[name] = v return renamed_state_dict class VisLatentDynamicsModel(pl.LightningModule): def __init__(self, lr: float=1e-4, seed: int=1, if_cuda: bool=True, if_test: bool=False, gamma: float=0.5, log_dir: str='logs', train_batch: int=512, val_batch: int=256, test_batch: int=256, num_workers: int=8, model_name: 
str='encoder-decoder-64', data_filepath: str='data', dataset: str='single_pendulum', lr_schedule: list=[20, 50, 100]) -> None: super().__init__() self.save_hyperparameters() self.kwargs = {'num_workers': self.hparams.num_workers, 'pin_memory': True} if self.hparams.if_cuda else {} # create visualization saving folder if testing self.pred_log_dir = os.path.join(self.hparams.log_dir, 'predictions') self.var_log_dir = os.path.join(self.hparams.log_dir, 'variables') if not self.hparams.if_test: mkdir(self.pred_log_dir) mkdir(self.var_log_dir) self.__build_model() def __build_model(self): # model if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'single_pendulum': self.model = LatentPredModel(in_channels=2) self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) self.refine_model = RefineSinglePendulumModel(in_channels=64) if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'double_pendulum': self.model = LatentPredModel(in_channels=4) self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) self.refine_model = RefineDoublePendulumModel(in_channels=64) if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'elastic_pendulum': self.model = LatentPredModel(in_channels=6) self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) self.refine_model = RefineElasticPendulumModel(in_channels=64) if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'swingstick_non_magnetic': self.model = LatentPredModel(in_channels=4) self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) self.refine_model = RefineSwingStickNonMagneticModel(in_channels=64) if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'air_dancer': self.model = LatentPredModel(in_channels=8) self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) self.refine_model = RefineAirDancerModel(in_channels=64) if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 
'lava_lamp': self.model = LatentPredModel(in_channels=8) self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) self.refine_model = RefineLavaLampModel(in_channels=64) if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'fire': self.model = LatentPredModel(in_channels=24) self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) self.refine_model = RefineFireModel(in_channels=64) if self.hparams.model_name == 'latent-prediction' and self.hparams.dataset == 'reaction_diffusion': self.model = LatentPredModel(in_channels=2) self.high_dim_model = EncoderDecoder64x1x1(in_channels=3) self.refine_model = RefineReactionDiffusionModel(in_channels=64) # loss self.loss_func = nn.MSELoss() def load_model(self, checkpoint_filepath): # load model for test checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) self.model.load_state_dict(ckpt) for p in self.model.parameters(): p.requires_grad = False self.model.eval() def load_high_dim_refine_model(self, high_dim_checkpoint_filepath, refine_checkpoint_filepath): # load high-dim and refine models high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(high_dim_checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) self.high_dim_model.load_state_dict(ckpt) refine_checkpoint_filepath = glob.glob(os.path.join(refine_checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(refine_checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) self.refine_model.load_state_dict(ckpt) for p in self.high_dim_model.parameters(): p.requires_grad = False self.high_dim_model.eval() for p in self.refine_model.parameters(): p.requires_grad = False self.refine_model.eval() def extract_decoder_from_refine_model(self): _layers = list(self.refine_model.children())[4:] self.refine_model_decoder = torch.nn.Sequential(*_layers) for p in 
self.refine_model_decoder.parameters(): p.requires_grad = False self.refine_model_decoder.eval() def data_to_state(self, data): # (B, 3, 128, 256) -> (B, 64, 1, 1) -> (B, 64) -> (B, ID) _, latent = self.high_dim_model(data, data, False) latent = latent.squeeze(-1).squeeze(-1) _, state = self.refine_model(latent) return state def training_step(self, batch, batch_idx): data, target, filepath = batch data_state = self.data_to_state(data) target_state = self.data_to_state(target) output_state = self.model(data_state) train_loss = self.loss_func(output_state, target_state) self.log('train_loss', train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) return train_loss def validation_step(self, batch, batch_idx): data, target, filepath = batch data_state = self.data_to_state(data) target_state = self.data_to_state(target) output_state = self.model(data_state) val_loss = self.loss_func(output_state, target_state) self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) return val_loss def test_step(self, batch, batch_idx): data, target, target_target, filepath = batch data_state = self.data_to_state(data) target_state = self.data_to_state(target) output_state = self.model(data_state) latent_reconstructed = self.refine_model_decoder(output_state) output, _ = self.high_dim_model(data, latent_reconstructed.unsqueeze(2).unsqueeze(3), True) # calculate losses test_loss = self.loss_func(output_state, target_state) pixel_reconstruction_loss = self.loss_func(output, target_target) self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) self.log('pixel_reconstruction_loss', pixel_reconstruction_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) # save the output images for idx in range(data.shape[0]): comparison = torch.cat([data[idx, :, :, :128].unsqueeze(0), data[idx, :, :, 128:].unsqueeze(0), target[idx, :, :, :128].unsqueeze(0), target[idx, :, :, 128:].unsqueeze(0), target_target[idx, :, 
:, :128].unsqueeze(0), target_target[idx, :, :, 128:].unsqueeze(0), output[idx, :, :, :128].unsqueeze(0), output[idx, :, :, 128:].unsqueeze(0)]) save_image(comparison.cpu(), os.path.join(self.pred_log_dir, filepath[idx]), nrow=1) def configure_optimizers(self): optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.hparams.lr) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.hparams.lr_schedule, gamma=self.hparams.gamma) return [optimizer], [scheduler] def setup(self, stage=None): if stage == 'fit': self.train_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath, flag='train', seed=self.hparams.seed, object_name=self.hparams.dataset) self.val_dataset = NeuralPhysDataset(data_filepath=self.hparams.data_filepath, flag='val', seed=self.hparams.seed, object_name=self.hparams.dataset) if stage == 'test': self.test_dataset = NeuralPhysLatentDynamicsDataset(data_filepath=self.hparams.data_filepath, flag='test', seed=self.hparams.seed, object_name=self.hparams.dataset) def train_dataloader(self): train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset, batch_size=self.hparams.train_batch, shuffle=True, **self.kwargs) return train_loader def val_dataloader(self): val_loader = torch.utils.data.DataLoader(dataset=self.val_dataset, batch_size=self.hparams.val_batch, shuffle=False, **self.kwargs) return val_loader def test_dataloader(self): test_loader = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.hparams.test_batch, shuffle=False, **self.kwargs) return test_loader ================================================ FILE: pred.py ================================================ """ A. Long-term future prediction (model rollout) 1. encoder-decoder (0, 1 -> 8192-dim latent -> 2', 3'): - feed (2', 3') images as input to predict (4', 5') images ... 2. encoder-decoder-64 (0, 1 -> 64-dim latent -> 2', 3'): - feed (2', 3') images as input to predict (4', 5') images ... 
3. encoder-decoder-64 & refine-64 (0, 1 -> id-dim latent -> 2', 3') - feed (2', 3') images as input to predict (4', 5') images ... 4. encoder-decoder-64 & refine-64 hybrid: - use refine-64 model at certain prediction steps B. Long-term future prediction with perturbation (model rollout) """ import os import sys import glob import yaml import json import torch import pprint import shutil import numpy as np from PIL import Image from tqdm import tqdm from munch import munchify from torchvision import transforms from collections import OrderedDict from models import VisDynamicsModel from models_latentpred import VisLatentDynamicsModel from dataset import NeuralPhysDataset from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.loggers import TensorBoardLogger def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def load_config(filepath): with open(filepath, 'r') as stream: try: trainer_params = yaml.safe_load(stream) return trainer_params except yaml.YAMLError as exc: print(exc) def seed(cfg): torch.manual_seed(cfg.seed) if cfg.if_cuda: torch.cuda.manual_seed(cfg.seed) # uncomment for strict reproducibility # torch.set_deterministic(True) def model_rollout(): config_filepath = str(sys.argv[2]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=True, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) # load model if cfg.model_name == 'encoder-decoder' or cfg.model_name == 'encoder-decoder-64': checkpoint_filepath = str(sys.argv[3]) checkpoint_filepath = 
glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(checkpoint_filepath) model.load_state_dict(ckpt['state_dict']) if 'refine' in cfg.model_name: checkpoint_filepath = str(sys.argv[4]) checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.model.load_state_dict(ckpt) high_dim_checkpoint_filepath = str(sys.argv[3]) high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(high_dim_checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.high_dim_model.load_state_dict(ckpt) model = model.to('cuda') model.eval() model.freeze() # get all the test video ids data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset) with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file: seq_dict = json.load(file) test_vid_ids = seq_dict['test'] pred_len = int(sys.argv[6]) long_term_folder = os.path.join(log_dir, 'prediction_long_term', 'model_rollout') loss_dict = {} if cfg.model_name == 'encoder-decoder' or cfg.model_name == 'encoder-decoder-64': for p_vid_idx in tqdm(test_vid_ids): vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx)) total_num_frames = len(os.listdir(vid_filepath)) suf = os.listdir(vid_filepath)[0].split('.')[-1] data = None saved_folder = os.path.join(long_term_folder, str(p_vid_idx)) mkdir(saved_folder) loss_lst = [] for start_frame_idx in range(total_num_frames - 3): if start_frame_idx % 2 != 0: continue # take the initial input from ground truth data if start_frame_idx % pred_len == 0: data = [get_data(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}')), get_data(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'))] data = (torch.cat(data, 2)).unsqueeze(0) # get the target target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')), get_data(os.path.join(vid_filepath, 
f'{start_frame_idx+3}.{suf}'))] target = (torch.cat(target, 2)).unsqueeze(0) # feed into the model if cfg.model_name == 'encoder-decoder': output, latent = model.model(data.cuda()) if cfg.model_name == 'encoder-decoder-64': output, latent = model.model(data.cuda(), data.cuda(), False) # compute loss loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy())) # save (2', 3'), (4', 5'), ... img = tensor_to_img(output[0, :, :, :128]) img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}')) img = tensor_to_img(output[0, :, :, 128:]) img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}')) # the output becomes the input data in the next iteration data = torch.tensor(output.cpu().detach().numpy()).float() loss_dict[p_vid_idx] = loss_lst # save the test loss for all the testing videos with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file: json.dump(loss_dict, file, indent=4) if 'refine' in cfg.model_name: for p_vid_idx in tqdm(test_vid_ids): vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx)) total_num_frames = len(os.listdir(vid_filepath)) suf = os.listdir(vid_filepath)[0].split('.')[-1] data = None saved_folder = os.path.join(long_term_folder, str(p_vid_idx)) mkdir(saved_folder) loss_lst = [] for start_frame_idx in range(total_num_frames - 3): if start_frame_idx % 2 != 0: continue # take the initial input from ground truth data if start_frame_idx % pred_len == 0: data = [get_data(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}')), get_data(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'))] data = (torch.cat(data, 2)).unsqueeze(0) # get the target target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')), get_data(os.path.join(vid_filepath, f'{start_frame_idx+3}.{suf}'))] target = (torch.cat(target, 2)).unsqueeze(0) # feed into the model _, latent = model.high_dim_model(data.cuda(), data.cuda(), False) latent = latent.squeeze(-1).squeeze(-1) latent_reconstructed, 
latent_latent = model.model(latent) output, _ = model.high_dim_model(data.cuda(), latent_reconstructed.unsqueeze(2).unsqueeze(3), True) # compute loss loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy())) # save (2', 3'), (4', 5'), ... img = tensor_to_img(output[0, :, :, :128]) img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}')) img = tensor_to_img(output[0, :, :, 128:]) img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}')) # the output becomes the input data in the next iteration data = torch.tensor(output.cpu().detach().numpy()).float() loss_dict[p_vid_idx] = loss_lst # save the test loss for all the testing videos with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file: json.dump(loss_dict, file, indent=4) def model_rollout_hybrid(step): config_filepath = str(sys.argv[2]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) if 'refine' not in cfg.model_name: assert False, "the hybrid scheme is only supported with refine model..." 
log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=True, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) # load model checkpoint_filepath = str(sys.argv[4]) checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.model.load_state_dict(ckpt) high_dim_checkpoint_filepath = str(sys.argv[3]) high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(high_dim_checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.high_dim_model.load_state_dict(ckpt) model = model.to('cuda') model.eval() model.freeze() # get all the test video ids data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset) with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file: seq_dict = json.load(file) test_vid_ids = seq_dict['test'] pred_len = int(sys.argv[6]) long_term_folder = os.path.join(log_dir, 'prediction_long_term', f'hybrid_rollout_{step}') loss_dict = {} for p_vid_idx in tqdm(test_vid_ids): vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx)) total_num_frames = len(os.listdir(vid_filepath)) suf = os.listdir(vid_filepath)[0].split('.')[-1] data = None saved_folder = os.path.join(long_term_folder, str(p_vid_idx)) mkdir(saved_folder) loss_lst = [] for start_frame_idx in range(total_num_frames - 3): if start_frame_idx % 2 != 0: continue # take the initial input from ground truth data if start_frame_idx % pred_len == 0: data = [get_data(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}')), get_data(os.path.join(vid_filepath, 
f'{start_frame_idx+1}.{suf}'))] data = (torch.cat(data, 2)).unsqueeze(0) # get the target target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')), get_data(os.path.join(vid_filepath, f'{start_frame_idx+3}.{suf}'))] target = (torch.cat(target, 2)).unsqueeze(0) # feed into the model if (start_frame_idx + 2) % (2 * step + 2) == 0: _, latent = model.high_dim_model(data.cuda(), data.cuda(), False) latent = latent.squeeze(-1).squeeze(-1) latent_reconstructed, latent_latent = model.model(latent) output, _ = model.high_dim_model(data.cuda(), latent_reconstructed.unsqueeze(2).unsqueeze(3), True) else: output, _ = model.high_dim_model(data.cuda(), data.cuda(), False) # compute loss loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy())) # save (2', 3'), (4', 5'), ... img = tensor_to_img(output[0, :, :, :128]) img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}')) img = tensor_to_img(output[0, :, :, 128:]) img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}')) # the output becomes the input data in the next iteration data = torch.tensor(output.cpu().detach().numpy()).float() loss_dict[p_vid_idx] = loss_lst # save the test loss for all the testing videos with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file: json.dump(loss_dict, file, indent=4) def model_rollout_perturb(perturb_type, perturb_level): config_filepath = str(sys.argv[2]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=True, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) # load model if 
cfg.model_name == 'encoder-decoder' or cfg.model_name == 'encoder-decoder-64': checkpoint_filepath = str(sys.argv[3]) checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(checkpoint_filepath) model.load_state_dict(ckpt['state_dict']) if 'refine' in cfg.model_name: checkpoint_filepath = str(sys.argv[4]) checkpoint_filepath = glob.glob(os.path.join(checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.model.load_state_dict(ckpt) high_dim_checkpoint_filepath = str(sys.argv[3]) high_dim_checkpoint_filepath = glob.glob(os.path.join(high_dim_checkpoint_filepath, '*.ckpt'))[0] ckpt = torch.load(high_dim_checkpoint_filepath) ckpt = rename_ckpt_for_multi_models(ckpt) model.high_dim_model.load_state_dict(ckpt) model = model.to('cuda') model.eval() model.freeze() # get all the test video ids data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset) with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file: seq_dict = json.load(file) test_vid_ids = seq_dict['test'] pred_len = int(sys.argv[6]) long_term_folder = os.path.join(log_dir, 'prediction_long_term', f'model_rollout_perturb_{perturb_type}_{perturb_level}') loss_dict = {} if cfg.model_name == 'encoder-decoder' or cfg.model_name == 'encoder-decoder-64': for p_vid_idx in tqdm(test_vid_ids): vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx)) total_num_frames = len(os.listdir(vid_filepath)) suf = os.listdir(vid_filepath)[0].split('.')[-1] data = None saved_folder = os.path.join(long_term_folder, str(p_vid_idx)) mkdir(saved_folder) loss_lst = [] for start_frame_idx in range(total_num_frames - 3): if start_frame_idx % 2 != 0: continue # take the initial input from ground truth data if start_frame_idx % pred_len == 0: data = [get_data_perturb(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}'), perturb_type, perturb_level), 
get_data_perturb(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'), perturb_type, perturb_level)] data = (torch.cat(data, 2)).unsqueeze(0) img = tensor_to_img(data[0, :, :, :128]) img.save(os.path.join(saved_folder, f'{start_frame_idx}.{suf}')) img = tensor_to_img(data[0, :, :, 128:]) img.save(os.path.join(saved_folder, f'{start_frame_idx+1}.{suf}')) # get the target target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')), get_data(os.path.join(vid_filepath, f'{start_frame_idx+3}.{suf}'))] target = (torch.cat(target, 2)).unsqueeze(0) # feed into the model if cfg.model_name == 'encoder-decoder': output, latent = model.model(data.cuda()) if cfg.model_name == 'encoder-decoder-64': output, latent = model.model(data.cuda(), data.cuda(), False) # compute loss loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy())) # save (2', 3'), (4', 5'), ... img = tensor_to_img(output[0, :, :, :128]) img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}')) img = tensor_to_img(output[0, :, :, 128:]) img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}')) # the output becomes the input data in the next iteration data = torch.tensor(output.cpu().detach().numpy()).float() loss_dict[p_vid_idx] = loss_lst # save the test loss for all the testing videos with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file: json.dump(loss_dict, file, indent=4) if 'refine' in cfg.model_name: for p_vid_idx in tqdm(test_vid_ids): vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx)) total_num_frames = len(os.listdir(vid_filepath)) suf = os.listdir(vid_filepath)[0].split('.')[-1] data = None saved_folder = os.path.join(long_term_folder, str(p_vid_idx)) mkdir(saved_folder) loss_lst = [] for start_frame_idx in range(total_num_frames - 3): if start_frame_idx % 2 != 0: continue # take the initial input from ground truth data if start_frame_idx % pred_len == 0: data = 
[get_data_perturb(os.path.join(vid_filepath, f'{start_frame_idx}.{suf}'), perturb_type, perturb_level), get_data_perturb(os.path.join(vid_filepath, f'{start_frame_idx+1}.{suf}'), perturb_type, perturb_level)] data = (torch.cat(data, 2)).unsqueeze(0) img = tensor_to_img(data[0, :, :, :128]) img.save(os.path.join(saved_folder, f'{start_frame_idx}.{suf}')) img = tensor_to_img(data[0, :, :, 128:]) img.save(os.path.join(saved_folder, f'{start_frame_idx+1}.{suf}')) # get the target target = [get_data(os.path.join(vid_filepath, f'{start_frame_idx+2}.{suf}')), get_data(os.path.join(vid_filepath, f'{start_frame_idx+3}.{suf}'))] target = (torch.cat(target, 2)).unsqueeze(0) # feed into the model _, latent = model.high_dim_model(data.cuda(), data.cuda(), False) latent = latent.squeeze(-1).squeeze(-1) latent_reconstructed, latent_latent = model.model(latent) output, _ = model.high_dim_model(data.cuda(), latent_reconstructed.unsqueeze(2).unsqueeze(3), True) # compute loss loss_lst.append(float(model.loss_func(output, target.cuda()).cpu().detach().numpy())) # save (2', 3'), (4', 5'), ... 
img = tensor_to_img(output[0, :, :, :128]) img.save(os.path.join(saved_folder, f'{start_frame_idx+2}.{suf}')) img = tensor_to_img(output[0, :, :, 128:]) img.save(os.path.join(saved_folder, f'{start_frame_idx+3}.{suf}')) # the output becomes the input data in the next iteration data = torch.tensor(output.cpu().detach().numpy()).float() loss_dict[p_vid_idx] = loss_lst # save the test loss for all the testing videos with open(os.path.join(long_term_folder, 'test_loss.json'), 'w') as file: json.dump(loss_dict, file, indent=4) def rename_ckpt_for_multi_models(ckpt): renamed_state_dict = OrderedDict() for k, v in ckpt['state_dict'].items(): if 'high_dim_model' in k: name = k.replace('high_dim_model.', '') else: name = k.replace('model.', '') renamed_state_dict[name] = v return renamed_state_dict def get_data(filepath): data = Image.open(filepath) data = data.resize((128, 128)) data = np.array(data) data = torch.tensor(data / 255.0) data = data.permute(2, 0, 1).float() return data def get_data_perturb(filepath, perturb_type, perturb_level): data = Image.open(filepath) data = data.resize((128, 128)) data = np.array(data) bg_color = np.array([215, 205, 192]) rng = np.random.RandomState(int(filepath.split('/')[-2])) new_bg_color = rng.randint(256, size=3) if perturb_type == 'background_replace': for i in range(2**(perturb_level-1)): for j in range(2**(perturb_level-1)): if np.array_equal(data[i, j], bg_color): data[i, j] = new_bg_color elif perturb_type == 'background_cover': for i in range(2**(perturb_level-1)): for j in range(2**(perturb_level-1)): data[i, j] = new_bg_color elif perturb_type == 'white_noise': sigma = 255.0 * (2**(perturb_level-1) / 128) ** 2 data = data + rng.normal(0, sigma, data.shape) else: pass data = torch.tensor(data / 255.0) data = data.permute(2, 0, 1).float() return data # out_tensor: 3 x 128 x 128 -> 128 x 128 x 3 def tensor_to_img(out_tensor): return transforms.ToPILImage()(out_tensor).convert("RGB") if __name__ == '__main__': if 
str(sys.argv[1]) == 'model-rollout': model_rollout() elif 'hybrid' in str(sys.argv[1]): step = int(sys.argv[1].split('-')[-1]) model_rollout_hybrid(step) elif str(sys.argv[1]) == 'latent-prediction': latent_prediction() elif 'perturb' in str(sys.argv[1]): perturb_type = str(sys.argv[1].split('-')[-2]) perturb_level = int(sys.argv[1].split('-')[-1]) model_rollout_perturb(perturb_type, perturb_level) else: assert False, "prediction scheme is not supported..." ================================================ FILE: requirements.txt ================================================ absl-py==0.7.1 aiohttp==3.7.4.post0 aiohttp-cors==0.7.0 aioredis==1.3.1 astor==0.7.1 astunparse==1.6.3 async-timeout==3.0.1 atari-py==0.1.15 atomicwrites==1.3.0 attrs==19.3.0 audioread==2.1.6 backcall==0.1.0 backports.shutil-get-terminal-size==1.0.0 bleach==1.5.0 blessings==1.7 cachetools==4.1.1 cattrs==1.0.0 certifi==2019.3.9 chardet==3.0.4 chumpy==0.70 click==8.0.1 cloudpickle==0.8.1 colorama==0.4.4 contextvars==2.4 cycler==0.10.0 Cython==0.29.22 dataclasses==0.8 decorator==4.4.0 deepdish==0.3.6 defusedxml==0.6.0 dill==0.2.9 Django==2.2.1 docopt==0.6.2 entrypoints==0.3 filelock==3.0.12 Flask==2.0.1 flatbuffers==1.12 fsspec==0.8.4 future==0.17.1 gast==0.3.3 glob2==0.6 google-api-core==1.28.0 google-auth==1.30.1 google-auth-oauthlib==0.4.1 google-pasta==0.2.0 googleapis-common-protos==1.53.0 gpustat==0.6.0 grpcio==1.34.0 gym==0.12.5 h5py==2.10.0 hiredis==2.0.0 html5lib==0.9999999 idna==2.8 idna-ssl==1.1.0 image==1.5.25 imageio==2.5.0 immutables==0.15 importlib-metadata==4.3.0 imutils==0.5.2 ipdb==0.12 ipykernel==5.1.0 ipython==7.4.0 ipython-genutils==0.2.0 ipywidgets==7.4.2 itsdangerous==2.0.1 jedi==0.13.3 Jinja2==3.0.1 joblib==0.13.2 jsonschema==3.0.1 jupyter==1.0.0 jupyter-client==5.2.4 jupyter-console==6.0.0 jupyter-core==4.4.0 kaleido==0.2.1 Keras-Preprocessing==1.1.2 kiwisolver==1.0.1 librosa==0.6.3 llvmlite==0.28.0 Markdown==3.1 MarkupSafe==2.0.1 matplotlib==2.2.2 mistune==0.8.4 
mock==3.0.5
more-itertools==7.0.0
mpi4py==3.0.3
msgpack==1.0.2
multidict==5.1.0
munch==2.5.0
munkres==1.0.12
nbconvert==5.4.1
nbformat==4.4.0
networkx==2.3
nibabel==2.4.0
notebook==5.7.8
numba==0.43.1
numexpr==2.6.9
numpy==1.19.4
nvidia-ml-py3==7.352.0
oauthlib==3.1.0
opencensus==0.7.13
opencensus-context==0.1.2
opencv-python==4.5.2.52
opt-einsum==3.3.0
packaging==20.9
pandas==0.24.2
pandocfilters==1.4.2
parso==0.3.4
pexpect==4.6.0
pickleshare==0.7.5
Pillow==7.1.2
# pkg-resources==0.0.0 removed: it is a Debian/Ubuntu `pip freeze` artifact,
# not a real PyPI release, and makes `pip install -r requirements.txt` fail.
plotly==4.14.3
pluggy==0.11.0
prometheus-client==0.10.1
prompt-toolkit==2.0.9
protobuf==3.17.1
psutil==5.8.0
ptyprocess==0.6.0
py==1.8.0
py-spy==0.3.7
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyglet==1.3.2
Pygments==2.3.1
PyOpenGL==3.1.0
pyparsing==2.3.1
pyrsistent==0.14.11
pytest==3.10.1
python-dateutil==2.8.0
pytorch-lightning==1.0.8
pytz==2019.1
PyWavelets==1.0.3
PyYAML==5.1
pyzmq==18.0.1
qtconsole==4.4.3
ray==1.3.0
redis==3.5.3
reprint==0.5.2
requests==2.21.0
requests-oauthlib==1.3.0
resampy==0.2.1
retrying==1.3.3
rsa==4.6
scikit-image==0.15.0
scikit-learn==0.20.3
scipy==1.4.1
seaborn==0.9.0
Send2Trash==1.5.0
Shapely==1.6.4.post2
six==1.16.0
sklearn==0.0
sqlparse==0.3.0
stable-baselines==2.5.1
tables==3.5.1
tabulate==0.8.9
tensorboard==2.2.2
tensorboard-plugin-wit==1.7.0
tensorboardX==2.2
termcolor==1.1.0
terminado==0.8.2
testpath==0.4.2
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.7.0
torchfile==0.1.0
torchvision==0.8.1
tornado==6.0.2
tqdm==4.54.1
traitlets==4.3.2
typing-extensions==3.7.4.3
urllib3==1.24.1
visdom==0.1.8.8
wcwidth==0.1.7
webencodings==0.5.1
websocket-client==0.56.0
Werkzeug==2.0.1
widgetsnbextension==3.4.2
wrapt==1.12.1
yarl==1.6.3
zipp==3.4.1
zmq==0.0.0

================================================
FILE: scripts/encoder_decoder_64_eval.sh
================================================
#!/bin/bash
# Run ../eval.py in 'eval-eval' mode for encoder-decoder-64 configs 1-3 of one
# dataset, sequentially, inside a detached screen session.
# Usage: encoder_decoder_64_eval.sh <dataset> <gpu id>
dataset=$1
gpu=$2

echo "============================================================================================="
echo "============== Evaluating encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "============================================================================================="

# $dataset/$gpu are expanded here (inside double quotes); \$i is expanded by the inner bash
screen -S eval-"$dataset"-64 -dm bash -c "for i in 1 2 3; do CUDA_VISIBLE_DEVICES=$gpu python ../eval.py ../configs/$dataset/model64/config\$i.yaml ./logs_${dataset}_encoder-decoder-64_\$i/lightning_logs/checkpoints NA eval-eval NA; done; exec sh"

================================================
FILE: scripts/encoder_decoder_64_eval_gather.sh
================================================
#!/bin/bash
# Same as encoder_decoder_64_eval.sh but runs ../eval.py in 'eval-train' mode.
# Usage: encoder_decoder_64_eval_gather.sh <dataset> <gpu id>
dataset=$1
gpu=$2

echo "========================================================================================================="
echo "============== Evaluating (Gathering) encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "========================================================================================================="

# $dataset/$gpu are expanded here (inside double quotes); \$i is expanded by the inner bash
screen -S eval-"$dataset"-64 -dm bash -c "for i in 1 2 3; do CUDA_VISIBLE_DEVICES=$gpu python ../eval.py ../configs/$dataset/model64/config\$i.yaml ./logs_${dataset}_encoder-decoder-64_\$i/lightning_logs/checkpoints NA eval-train NA; done; exec sh"

================================================
FILE: scripts/encoder_decoder_64_long_term_model_rollout.sh
================================================
#!/bin/bash
# Long-term model rollout with the three encoder-decoder-64 models
# (model64/config1..3) on one dataset. The runs execute one after another in a
# detached screen session, which drops into `sh` once they have finished.
# Paths are relative to the current directory (presumably scripts/).
# Usage: ./encoder_decoder_64_long_term_model_rollout.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
# Fail fast: a missing argument would otherwise launch every job with empty
# path components / an empty GPU id and fail unseen inside the detached session.
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "=========================================================================================================="
echo "============== Long-term model rollout encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "=========================================================================================================="
cmds=""
for i in 1 2 3; do
    cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../pred.py model-rollout ../configs/$dataset/model64/config$i.yaml ./logs_${dataset}_encoder-decoder-64_$i/lightning_logs/checkpoints NA NA 60; "
done
screen -S eval-"$dataset"-longterm-modelrollout -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/encoder_decoder_64_long_term_model_rollout_perturb_all.sh
================================================
#!/bin/bash
# Long-term model rollout of the encoder-decoder-64 models (model64/config1..3)
# under every input perturbation, in a detached screen session.
# Each perturbation is run for config1, config2, config3 before the next one.
# Usage: ./encoder_decoder_64_long_term_model_rollout_perturb_all.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "=========================================================================================================="
echo "============== Long-term model rollout encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "=========================================================================================================="
cmds=""
for perturb in background_cover-5 background_cover-7 background_cover-8 \
               background_replace-5 background_replace-7 background_replace-8 \
               white_noise-5 white_noise-7 white_noise-8; do
    for i in 1 2 3; do
        cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../pred.py model-rollout-perturb-$perturb ../configs/$dataset/model64/config$i.yaml ./logs_${dataset}_encoder-decoder-64_$i/lightning_logs/checkpoints NA NA 60; "
    done
done
screen -S eval-"$dataset"-longterm-modelrollout -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/encoder_decoder_64_train.sh
================================================
#!/bin/bash
# Train the three encoder-decoder-64 models (model64/config1..3) on one dataset,
# sequentially, in a detached screen session.
# Usage: ./encoder_decoder_64_train.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "==========================================================================================="
echo "============== Training encoder-decoder-64 model on: $dataset (gpu id: $gpu) =============="
echo "==========================================================================================="
cmds=""
for i in 1 2 3; do
    cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../main.py ../configs/$dataset/model64/config$i.yaml; "
done
screen -S train-"$dataset"-64 -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/encoder_decoder_estimate_dimension.sh
================================================
#!/bin/bash
# Estimate the intrinsic dimension from the latents of the three
# encoder-decoder models (model/config1..3). CPU only: each run is limited to
# OMP_NUM_THREADS=4 and no GPU is selected.
# Usage: ./encoder_decoder_estimate_dimension.sh DATASET
dataset="${1:?usage: $0 DATASET}"
echo "============================================================================================"
echo "========== Estimating intrinsic dimension from encoder-decoder model on: $dataset =========="
echo "============================================================================================"
cmds=""
for i in 1 2 3; do
    cmds+="OMP_NUM_THREADS=4 python ../analysis/eval_intrinsic_dimension.py ../configs/$dataset/model/config$i.yaml model-latent NA; "
done
screen -S eval-"$dataset"-dimension -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/encoder_decoder_eval.sh
================================================
#!/bin/bash
# Evaluate (mode eval-eval) the three encoder-decoder models
# (model/config1..3) from their saved checkpoints.
# Usage: ./encoder_decoder_eval.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "=========================================================================================="
echo "============== Evaluating encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "=========================================================================================="
cmds=""
for i in 1 2 3; do
    cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../eval.py ../configs/$dataset/model/config$i.yaml ./logs_${dataset}_encoder-decoder_$i/lightning_logs/checkpoints NA eval-eval NA; "
done
# NOTE(review): same session name as encoder_decoder_eval_gather.sh.
screen -S eval-"$dataset" -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/encoder_decoder_eval_gather.sh
================================================
#!/bin/bash
# Gather outputs (mode eval-train) of the three encoder-decoder models
# (model/config1..3) from their saved checkpoints.
# Usage: ./encoder_decoder_eval_gather.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "======================================================================================================"
echo "============== Evaluating (Gathering) encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "======================================================================================================"
cmds=""
for i in 1 2 3; do
    cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../eval.py ../configs/$dataset/model/config$i.yaml ./logs_${dataset}_encoder-decoder_$i/lightning_logs/checkpoints NA eval-train NA; "
done
screen -S eval-"$dataset" -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/encoder_decoder_long_term_model_rollout.sh
================================================
#!/bin/bash
# Long-term model rollout with the three encoder-decoder models
# (model/config1..3), sequentially, in a detached screen session.
# Usage: ./encoder_decoder_long_term_model_rollout.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "======================================================================================================="
echo "============== Long-term model rollout encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "======================================================================================================="
cmds=""
for i in 1 2 3; do
    cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../pred.py model-rollout ../configs/$dataset/model/config$i.yaml ./logs_${dataset}_encoder-decoder_$i/lightning_logs/checkpoints NA NA 60; "
done
screen -S eval-"$dataset"-longterm-modelrollout -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/encoder_decoder_long_term_model_rollout_perturb_all.sh
================================================
#!/bin/bash
# Long-term model rollout of the encoder-decoder models (model/config1..3)
# under every input perturbation, in a detached screen session.
# The perturbation order below is deliberately kept as in the original script
# (it differs from the encoder-decoder-64 variant).
# Usage: ./encoder_decoder_long_term_model_rollout_perturb_all.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "======================================================================================================="
echo "============== Long-term model rollout encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "======================================================================================================="
cmds=""
for perturb in background_replace-8 \
               white_noise-5 white_noise-7 white_noise-8 \
               background_cover-5 background_cover-7 background_cover-8 \
               background_replace-5 background_replace-7; do
    for i in 1 2 3; do
        cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../pred.py model-rollout-perturb-$perturb ../configs/$dataset/model/config$i.yaml ./logs_${dataset}_encoder-decoder_$i/lightning_logs/checkpoints NA NA 60; "
    done
done
screen -S eval-"$dataset"-longterm-modelrollout -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/encoder_decoder_train.sh
================================================
#!/bin/bash
# Train the three encoder-decoder models (model/config1..3) on one dataset,
# sequentially, in a detached screen session.
# Usage: ./encoder_decoder_train.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "========================================================================================"
echo "============== Training encoder-decoder model on: $dataset (gpu id: $gpu) =============="
echo "========================================================================================"
cmds=""
for i in 1 2 3; do
    cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../main.py ../configs/$dataset/model/config$i.yaml; "
done
screen -S train-"$dataset" -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/latentpred_train.sh
================================================
#!/bin/bash
# Train the three latent-prediction models (latentpred/config1..3). Each run
# reads the checkpoints of the matching encoder-decoder-64 and refine-64 runs,
# so those must have been trained first.
# Usage: ./latentpred_train.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "========================================================================================"
echo "================ Training latentpred model on: $dataset (gpu id: $gpu) ================="
echo "========================================================================================"
cmds=""
for i in 1 2 3; do
    cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../main.py latentpred ../configs/$dataset/latentpred/config$i.yaml ./logs_${dataset}_encoder-decoder-64_$i/lightning_logs/checkpoints/ ./logs_${dataset}_refine-64_$i/lightning_logs/checkpoints/; "
done
# NOTE(review): encoder_decoder_train.sh uses the same session name
# (train-$dataset); give one of them a distinct name if both may run at once.
screen -S train-"$dataset" -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/long_term_eval_stability.sh
================================================
#!/bin/bash
# Evaluate long-term prediction stability for configs 1..3. For each config,
# the stored model rollouts of the encoder-decoder, encoder-decoder-64 and
# refine-64 runs are evaluated in that order.
# Usage: ./long_term_eval_stability.sh DATASET GPU_ID
usage="usage: $0 DATASET GPU_ID"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
echo "============================================================================================"
echo "========= Evaluating stability of long term prediction on: $dataset (gpu id: $gpu) ========="
echo "============================================================================================"
cmds=""
for i in 1 2 3; do
    for model in encoder-decoder encoder-decoder-64 refine-64; do
        cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../stability.py ../configs/$dataset/latentpred/config$i.yaml ./logs_${dataset}_latent-prediction_$i/lightning_logs/checkpoints/ ./logs_${dataset}_encoder-decoder-64_$i/lightning_logs/checkpoints/ ./logs_${dataset}_refine-64_$i/lightning_logs/checkpoints/ ./logs_${dataset}_${model}_$i/prediction_long_term/model_rollout/ 60; "
    done
done
screen -S eval-"$dataset"-long-term-stab -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/long_term_eval_stability_hybrid.sh
================================================
#!/bin/bash
# Evaluate long-term prediction stability of the refine-64 hybrid rollouts
# (hybrid_rollout_STEP) for configs 1..3.
# Usage: ./long_term_eval_stability_hybrid.sh DATASET GPU_ID STEP
usage="usage: $0 DATASET GPU_ID STEP"
dataset="${1:?$usage}"
gpu="${2:?$usage}"
# Without STEP the rollout path would silently become hybrid_rollout_/.
step="${3:?$usage}"
echo "============================================================================================"
echo "========= Evaluating stability of long term prediction on: $dataset (gpu id: $gpu) ========="
echo "============================================================================================"
cmds=""
for i in 1 2 3; do
    cmds+="CUDA_VISIBLE_DEVICES=$gpu python ../stability.py ../configs/$dataset/latentpred/config$i.yaml ./logs_${dataset}_latent-prediction_$i/lightning_logs/checkpoints/ ./logs_${dataset}_encoder-decoder-64_$i/lightning_logs/checkpoints/ ./logs_${dataset}_refine-64_$i/lightning_logs/checkpoints/ ./logs_${dataset}_refine-64_$i/prediction_long_term/hybrid_rollout_$step/ 60; "
done
screen -S eval-"$dataset"-long-term-stab -dm bash -c "${cmds}exec sh"

================================================
FILE: scripts/refine_64_eval.sh
================================================
#!/bin/bash
dataset=$1 gpu=$2 echo "====================================================================================" echo "============== Evaluating refine-64 model on: $dataset (gpu id: $gpu) ==============" echo "====================================================================================" screen -S eval-"$dataset"-refine-64 -dm bash -c "CUDA_VISIBLE_DEVICES="$gpu" python ../eval.py ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints eval-eval NA; \ CUDA_VISIBLE_DEVICES="$gpu" python ../eval.py ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints eval-eval NA; \ CUDA_VISIBLE_DEVICES="$gpu" python ../eval.py ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints eval-eval NA; \ exec sh"; ================================================ FILE: scripts/refine_64_eval_gather.sh ================================================ #!/bin/bash dataset=$1 gpu=$2 echo "================================================================================================" echo "============== Evaluating (Gathering) refine-64 model on: $dataset (gpu id: $gpu) ==============" echo "================================================================================================" screen -S eval-"$dataset"-refine-64 -dm bash -c "CUDA_VISIBLE_DEVICES="$gpu" python ../eval.py ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints eval-refine-train NA; \ CUDA_VISIBLE_DEVICES="$gpu" python ../eval.py ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints 
./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints eval-refine-train NA; \ CUDA_VISIBLE_DEVICES="$gpu" python ../eval.py ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints eval-refine-train NA; \ exec sh"; ================================================ FILE: scripts/refine_64_long_term_hybrid_rollout.sh ================================================ #!/bin/bash dataset=$1 gpu=$2 step=$3 echo "================================================================================================" echo "======= Long-term hybrid-$step model rollout refine-64 model on: $dataset (gpu id: $gpu) =======" echo "================================================================================================" screen -S eval-"$dataset"-longterm-hybridrollout -dm bash -c "CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py hybrid-$step ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py hybrid-$step ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py hybrid-$step ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ exec sh"; ================================================ FILE: scripts/refine_64_long_term_model_rollout.sh ================================================ #!/bin/bash dataset=$1 gpu=$2 echo "=================================================================================================" echo "============== Long-term model rollout refine-64 model on: $dataset (gpu id: $gpu) 
==============" echo "=================================================================================================" screen -S eval-"$dataset"-longterm-modelrollout -dm bash -c "CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ exec sh"; ================================================ FILE: scripts/refine_64_long_term_model_rollout_perturb_all.sh ================================================ #!/bin/bash dataset=$1 gpu=$2 echo "=================================================================================================" echo "============== Long-term model rollout refine-64 model on: $dataset (gpu id: $gpu) ==============" echo "=================================================================================================" screen -S eval-"$dataset"-longterm-modelrollout -dm bash -c "CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ 
CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-5 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-7 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_cover-8 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/"$dataset"/refine64/config1.yaml 
./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-5 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-7 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ 
CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-background_replace-8 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-5 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/"$dataset"/refine64/config1.yaml ./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-7 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/"$dataset"/refine64/config1.yaml 
./logs_"$dataset"_encoder-decoder-64_1/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_1/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/"$dataset"/refine64/config2.yaml ./logs_"$dataset"_encoder-decoder-64_2/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_2/lightning_logs/checkpoints NA 60; \ CUDA_VISIBLE_DEVICES="$gpu" python ../pred.py model-rollout-perturb-white_noise-8 ../configs/"$dataset"/refine64/config3.yaml ./logs_"$dataset"_encoder-decoder-64_3/lightning_logs/checkpoints ./logs_"$dataset"_refine-64_3/lightning_logs/checkpoints NA 60; \ exec sh"; ================================================ FILE: scripts/refine_64_train.sh ================================================ #!/bin/bash dataset=$1 gpu=$2 echo "===================================================================================" echo "============== Training refine-64 model on: $dataset (gpu id: $gpu) ===============" echo "===================================================================================" screen -S train-"$dataset"-refine-64 -dm bash -c "CUDA_VISIBLE_DEVICES="$gpu" python ../main.py ../configs/"$dataset"/refine64/config1.yaml; \ CUDA_VISIBLE_DEVICES="$gpu" python ../main.py ../configs/"$dataset"/refine64/config2.yaml; \ CUDA_VISIBLE_DEVICES="$gpu" python ../main.py ../configs/"$dataset"/refine64/config3.yaml; \ exec sh"; ================================================ FILE: stability.py ================================================ import os import sys import glob import yaml import json import torch import pprint import shutil import numpy as np from PIL import Image from tqdm import tqdm from munch import munchify from models_latentpred import VisLatentDynamicsModel from dataset import NeuralPhysDataset from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.loggers import TensorBoardLogger def mkdir(folder): if os.path.exists(folder): 
shutil.rmtree(folder) os.makedirs(folder) def load_config(filepath): with open(filepath, 'r') as stream: try: trainer_params = yaml.safe_load(stream) return trainer_params except yaml.YAMLError as exc: print(exc) def seed(cfg): torch.manual_seed(cfg.seed) if cfg.if_cuda: torch.cuda.manual_seed(cfg.seed) def get_data(filepath): data = Image.open(filepath) data = data.resize((128, 128)) data = np.array(data) data = torch.tensor(data / 255.0) data = data.permute(2, 0, 1).float() return data def main(): config_filepath = str(sys.argv[1]) checkpoint_filepath = str(sys.argv[2]) high_dim_checkpoint_filepath = str(sys.argv[3]) refine_checkpoint_filepath = str(sys.argv[4]) pred_save_path = str(sys.argv[5]) pred_len = int(sys.argv[6]) cfg = load_config(filepath=config_filepath) pprint.pprint(cfg) cfg = munchify(cfg) seed(cfg) seed_everything(cfg.seed) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) model = VisLatentDynamicsModel(lr=cfg.lr, seed=cfg.seed, if_cuda=cfg.if_cuda, if_test=True, gamma=cfg.gamma, log_dir=log_dir, train_batch=cfg.train_batch, val_batch=cfg.val_batch, test_batch=cfg.test_batch, num_workers=cfg.num_workers, model_name=cfg.model_name, data_filepath=cfg.data_filepath, dataset=cfg.dataset, lr_schedule=cfg.lr_schedule) model.load_model(checkpoint_filepath) model.load_high_dim_refine_model(high_dim_checkpoint_filepath, refine_checkpoint_filepath) model.extract_decoder_from_refine_model() model = model.to('cuda') model.eval() model.freeze() # get all the test video ids data_filepath_base = os.path.join(cfg.data_filepath, cfg.dataset) with open(os.path.join('../datainfo', cfg.dataset, f'data_split_dict_{cfg.seed}.json'), 'r') as file: seq_dict = json.load(file) test_vid_ids = seq_dict['test'] errors = [] for p_vid_idx in tqdm(test_vid_ids): data_vid_filepath = os.path.join(data_filepath_base, str(p_vid_idx)) pred_vid_filepath = os.path.join(pred_save_path, str(p_vid_idx)) total_num_frames = len(os.listdir(data_vid_filepath)) suf 
= os.listdir(data_vid_filepath)[0].split('.')[-1] error_p = [] for start_frame_idx in range(total_num_frames - 3): if start_frame_idx % 2 != 0: continue # get the data if start_frame_idx % pred_len == 0: data = [get_data(os.path.join(data_vid_filepath, f'{start_frame_idx}.{suf}')), get_data(os.path.join(data_vid_filepath, f'{start_frame_idx+1}.{suf}'))] data = (torch.cat(data, 2)).unsqueeze(0).cuda() else: data = [get_data(os.path.join(pred_vid_filepath, f'{start_frame_idx}.{suf}')), get_data(os.path.join(pred_vid_filepath, f'{start_frame_idx+1}.{suf}'))] data = (torch.cat(data, 2)).unsqueeze(0).cuda() # get the target target = [get_data(os.path.join(pred_vid_filepath, f'{start_frame_idx+2}.{suf}')), get_data(os.path.join(pred_vid_filepath, f'{start_frame_idx+3}.{suf}'))] target = (torch.cat(target, 2)).unsqueeze(0).cuda() # compute latent space error data_state = model.data_to_state(data) target_state = model.data_to_state(target) output_state = model.model(data_state) error_p.append(float(model.loss_func(output_state, target_state).cpu().detach().numpy())) errors.append(error_p) errors = np.array(errors) np.save(os.path.join(pred_save_path, 'stability.npy'), errors) if __name__ == '__main__': main() ================================================ FILE: utils/common.py ================================================ import os import shutil import json import random import numpy as np from PIL import Image import plotly.graph_objects as go def mkdir(folder): if os.path.exists(folder): shutil.rmtree(folder) os.makedirs(folder) def create_random_data_splits(seed, data_filepath, object_name, ratio=0.8): random.seed(seed) seq_dict = {} obj_filepath = os.path.join(data_filepath, object_name) num_vids = int(len(os.listdir(obj_filepath)) - 1) vid_id_lst = list(range(num_vids)) random.shuffle(vid_id_lst) # test start = int(num_vids * (ratio + (1 - ratio) / 2)) seq_dict['test'] = vid_id_lst[start:] # val start = int(num_vids * ratio) end = int(num_vids * (ratio + (1 - 
ratio) / 2)) seq_dict['val'] = vid_id_lst[start:end] # train seq_dict['train'] = vid_id_lst[:int(num_vids * ratio)] with open(os.path.join(f'./datainfo/{object_name}/data_split_dict_{seed}.json'), 'w') as file: json.dump(seq_dict, file, indent=4) def separate_eval_image(filepath): img = Image.open(filepath) img.crop((2, 2, 130, 130)).save('input_0.png') img.crop((2, 132, 130, 260)).save('input_1.png') img.crop((2, 262, 130, 390)).save('truth_0.png') img.crop((2, 392, 130, 520)).save('truth_1.png') img.crop((2, 522, 130, 650)).save('output_0.png') img.crop((2, 652, 130, 780)).save('output_1.png') def collect_long_term_pred_sample(dataset, seed, test_vid_id, hybrid_step, start=0, step=12, stop=60, suf='png'): idv = id_value[dataset] fps = fps_value[dataset] save_path = f'long_term_pred_sample_{dataset}_seed_{seed}_vid_{test_vid_id}' mkdir(save_path) data_filepath = f'/data/kuang/visphysics_data/{dataset}/{test_vid_id}' pred_filepaths = {'dim-8192':f'/data/kuang/logs/logs_{dataset}_encoder-decoder_{seed}/prediction_long_term/model_rollout/{test_vid_id}', 'dim-64':f'/data/kuang/logs/logs_{dataset}_encoder-decoder-64_{seed}/prediction_long_term/model_rollout/{test_vid_id}', f'dim-{idv}':f'/data/kuang/logs/logs_{dataset}_refine-64_{seed}/prediction_long_term/model_rollout/{test_vid_id}', f'hybrid-64-{idv}':f'/data/kuang/logs/logs_{dataset}_refine-64_{seed}/prediction_long_term/hybrid_rollout_{hybrid_step}/{test_vid_id}'} Image.open(os.path.join(data_filepath, str(start)+'.'+suf)).save(os.path.join(save_path, 'input_0.png')) Image.open(os.path.join(data_filepath, str(start+1)+'.'+suf)).save(os.path.join(save_path, 'input_1.png')) for k in range(start+step, stop+1, step): data = Image.open(os.path.join(data_filepath, str(k-1)+'.'+suf)) data.save(os.path.join(save_path, 'truth_%.1fs.png'%((k-start)/fps))) for scheme in ['dim-8192', 'dim-64', f'dim-{idv}', f'hybrid-64-{idv}']: pred = Image.open(os.path.join(pred_filepaths[scheme], str(k-1)+'.'+suf)) 
pred.save(os.path.join(save_path, scheme+'_%.1fs.png'%((k-start)/fps))) """ physical system parameters """ id_value = {'circular_motion':2, 'reaction_diffusion':2, 'single_pendulum':2, 'double_pendulum':4, 'swingstick_non_magnetic':4, 'elastic_pendulum':6, 'air_dancer':8, 'lava_lamp':8, 'fire':24} fps_value = {'circular_motion':60, 'reaction_diffusion':5, 'single_pendulum':60, 'double_pendulum':60, 'swingstick_non_magnetic':60, 'elastic_pendulum':60, 'air_dancer':60, 'lava_lamp':2.5, 'fire':24} """ plot style """ cols = ['#df1e1e', '#f0975a', '#1b66cb', '#149650'] colorscale=[[0.0, "rgb(49,54,149)"], [0.1111111111111111, "rgb(69,117,180)"], [0.2222222222222222, "rgb(116,173,209)"], [0.3333333333333333, "rgb(171,217,233)"], [0.4444444444444444, "rgb(224,243,248)"], [0.5555555555555556, "rgb(254,224,144)"], [0.6666666666666666, "rgb(253,174,97)"], [0.7777777777777778, "rgb(244,109,67)"], [0.8888888888888888, "rgb(215,48,39)"], [1.0, "rgb(165,0,38)"]] def light_color(col): if col[:3] == 'rgb': rgb = cols[n][4:-1] elif col[0] == '#': rgb = str(tuple(int(col.lstrip('#')[i:i+2],16) for i in (0,2,4)))[1:-1] return 'rgba('+rgb +',0.2)' def update_figure(fig): fig.update_xaxes(showgrid=False, tickfont=dict(family="sans serif", size=18, color="black")) fig.update_yaxes(showgrid=False, tickfont=dict(family="sans serif", size=18, color="black")) fig.update_xaxes(showline=True, linewidth=2, linecolor="black") fig.update_yaxes(showline=True, linewidth=2, linecolor="black") fig.update_layout(legend=go.layout.Legend(traceorder="normal", font=dict(family="sans serif", size=18, color="black"), )) fig.update_layout(font=dict(family="sans serif", size=24, color="black"), paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)', ) def test(): fig = go.Figure() x = np.arange(30) for k in range(4): fig.add_trace(go.Scatter(x=x, y=x*(k+1), mode='lines', line=dict(width=4, color=cols[k]))) update_figure(fig) fig.write_image('test.png', scale=4) """ test code only """ if __name__ == 
'__main__': test() ================================================ FILE: utils/dimension.py ================================================ import numpy as np from scipy import stats import os import sys import yaml from munch import munchify def load_config(filepath): with open(filepath, 'r') as stream: try: trainer_params = yaml.safe_load(stream) return trainer_params except yaml.YAMLError as exc: print(exc) def calculate_intrinsic_dimension_statistics(dataset, model_type='model', if_image=False): configs_dir = os.path.join('../configs', dataset, model_type) if if_image: filename = 'intrinsic_dimension_image.npy' else: filename = 'intrinsic_dimension.npy' dims_all = [] for config_filepath in os.listdir(configs_dir): cfg = load_config(filepath=os.path.join(configs_dir, config_filepath)) cfg = munchify(cfg) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) vars_filepath = os.path.join(log_dir, 'variables') dims = np.load(os.path.join(vars_filepath, filename)) dims_all.append(dims) dims_all = np.concatenate(dims_all) dim_mean = np.mean(dims_all) dim_std = np.std(dims_all) print('Mean (±std):', '%.2f (±%.2f)' % (dim_mean, dim_std)) print('Confidence interval:', '(%.1f, %.1f)' % (dim_mean-1.96*dim_std, dim_mean+1.96*dim_std)) def calculate_intrinsic_dimension_statistics_all_methods(dataset, model_type='model', if_image=False): configs_dir = os.path.join('../configs', dataset, model_type) if if_image: filename = 'intrinsic_dimension_image_all_methods.npy' else: filename = 'intrinsic_dimension_all_methods.npy' all_methods = ['Levina_Bickel', 'MiND_ML', 'MiND_KL', 'Hein', 'CD'] dims_all = {method:[] for method in all_methods} for config_filepath in os.listdir(configs_dir): cfg = load_config(filepath=os.path.join(configs_dir, config_filepath)) cfg = munchify(cfg) log_dir = '_'.join([cfg.log_dir, cfg.dataset, cfg.model_name, str(cfg.seed)]) vars_filepath = os.path.join(log_dir, 'variables') dims = np.load(os.path.join(vars_filepath, filename), 
allow_pickle=True).item() for method in all_methods: dims_all[method].append(dims[method]) for method in all_methods: if method not in ['Hein', 'CD']: dims_all[method] = np.concatenate(dims_all[method]) dim_mean = np.mean(dims_all[method]) dim_std = np.std(dims_all[method]) print(method + ':', '%.2f (±%.2f)' % (dim_mean, dim_std)) if __name__ == '__main__': dataset = str(sys.argv[1]) calculate_intrinsic_dimension_statistics(dataset, 'model') # calculate_intrinsic_dimension_statistics(dataset, 'model', if_image=True) # calculate_intrinsic_dimension_statistics_all_methods(dataset, 'model')